syncthing/client/client.go

279 lines
5.3 KiB
Go
Raw Normal View History

2015-06-27 17:52:01 -07:00
// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
package client
import (
"crypto/tls"
"fmt"
"log"
"net"
"net/url"
"time"
syncthingprotocol "github.com/syncthing/protocol"
"github.com/syncthing/relaysrv/protocol"
2015-07-22 14:34:05 -07:00
"github.com/syncthing/syncthing/internal/sync"
2015-06-27 17:52:01 -07:00
)
type ProtocolClient struct {
URI *url.URL
Invitations chan protocol.SessionInvitation
closeInvitationsOnFinish bool
config *tls.Config
timeout time.Duration
stop chan struct{}
stopped chan struct{}
conn *tls.Conn
2015-07-22 14:34:05 -07:00
mut sync.RWMutex
connected bool
2015-06-27 17:52:01 -07:00
}
2015-07-20 02:56:10 -07:00
// NewProtocolClient returns a client for the relay at uri that presents certs
// during the TLS handshake. Invitations from the relay are delivered on
// invitations. If invitations is nil, the client allocates its own unbuffered
// channel and takes responsibility for closing it (see cleanup).
func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) *ProtocolClient {
	ownsChannel := invitations == nil
	if ownsChannel {
		invitations = make(chan protocol.SessionInvitation)
	}

	// connected is left at its zero value (false) until Serve joins the relay.
	client := &ProtocolClient{
		URI:                      uri,
		Invitations:              invitations,
		closeInvitationsOnFinish: ownsChannel,
		config:                   configForCerts(certs),
		timeout:                  2 * time.Minute,
		stop:                     make(chan struct{}),
		stopped:                  make(chan struct{}),
		mut:                      sync.NewRWMutex(),
	}
	return client
}
// Serve connects to and joins the relay at c.URI, then processes relay
// messages until one of the following happens:
//   - Stop is called.
//   - Reading from or writing to the relay fails.
//   - The relay sends an unexpected message.
//   - No message arrives within c.timeout.
//
// Ping messages are answered with Pong. SessionInvitation messages are
// forwarded on c.Invitations. Serve always closes c.stopped on return so that
// Stop can wait for it.
func (c *ProtocolClient) Serve() {
	c.stop = make(chan struct{})
	c.stopped = make(chan struct{})
	defer close(c.stopped)

	if err := c.connect(); err != nil {
		l.Infoln("Relay connect:", err)
		return
	}

	if debug {
		l.Debugln(c, "connected", c.conn.RemoteAddr())
	}

	if err := c.join(); err != nil {
		c.conn.Close()
		l.Infoln("Relay join:", err)
		return
	}

	// connect leaves a handshake deadline on the connection. Clear it so that
	// idleness is governed solely by c.timeout in the loop below.
	if err := c.conn.SetDeadline(time.Time{}); err != nil {
		// cleanup is not yet deferred here, so close the connection ourselves.
		c.conn.Close()
		l.Infoln("Relay set deadline:", err)
		return
	}

	if debug {
		l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr())
	}

	defer c.cleanup()
	c.mut.Lock()
	c.connected = true
	c.mut.Unlock()

	messages := make(chan interface{})
	errors := make(chan error, 1)

	// NOTE(review): messageReader takes no stop signal. If Serve returns while
	// the reader is blocked sending on the unbuffered messages channel, that
	// goroutine is never released. Closing the connection (done by cleanup)
	// presumably only unblocks a reader that is waiting inside ReadMessage.
	go messageReader(c.conn, messages, errors)

	timeout := time.NewTimer(c.timeout)
	defer timeout.Stop()

	for {
		select {
		case message := <-messages:
			timeout.Reset(c.timeout)
			if debug {
				log.Printf("%s received message %T", c, message)
			}

			switch msg := message.(type) {
			case protocol.Ping:
				if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil {
					l.Infoln("Relay write:", err)
					return
				}
				if debug {
					l.Debugln(c, "sent pong")
				}

			case protocol.SessionInvitation:
				// If the relay supplied no usable address, fall back to the
				// address we reached the relay on.
				ip := net.IP(msg.Address)
				if len(ip) == 0 || ip.IsUnspecified() {
					msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:]
				}
				// Watch c.stop while delivering so that a consumer that stops
				// reading Invitations cannot make Stop wait forever.
				select {
				case c.Invitations <- msg:
				case <-c.stop:
					if debug {
						l.Debugln(c, "stopping")
					}
					return
				}

			default:
				l.Infoln("Relay: protocol error: unexpected message", msg)
				return
			}

		case <-c.stop:
			if debug {
				l.Debugln(c, "stopping")
			}
			return

		case err := <-errors:
			l.Infoln("Relay received:", err)
			return

		case <-timeout.C:
			if debug {
				l.Debugln(c, "timed out")
			}
			return
		}
	}
}
func (c *ProtocolClient) Stop() {
if c.stop == nil {
return
}
2015-07-20 02:56:10 -07:00
close(c.stop)
2015-06-27 17:52:01 -07:00
<-c.stopped
}
2015-07-22 14:34:05 -07:00
// StatusOK reports whether the client has joined the relay and Serve is
// currently running its message loop.
func (c *ProtocolClient) StatusOK() bool {
	c.mut.RLock()
	defer c.mut.RUnlock()
	return c.connected
}
2015-06-27 17:52:01 -07:00
// String identifies this client instance by its pointer, for use in log lines.
func (c *ProtocolClient) String() string {
	label := fmt.Sprintf("ProtocolClient@%p", c)
	return label
}
2015-07-20 02:56:10 -07:00
// connect dials the relay named by c.URI over TLS. It bounds the handshake
// and validation with a 10-second deadline and checks the relay with
// performHandshakeAndValidation.
//
// On success the connection is stored in c.conn, with that deadline still
// set; the caller is expected to clear it. On any failure after dialing, the
// connection is closed and the error is returned.
func (c *ProtocolClient) connect() error {
	if c.URI.Scheme != "relay" {
		return fmt.Errorf("Unsupported relay schema: %v", c.URI.Scheme)
	}

	conn, err := tls.Dial("tcp", c.URI.Host, c.config)
	if err != nil {
		return err
	}

	if err := conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
		conn.Close()
		return err
	}

	if err := performHandshakeAndValidation(conn, c.URI); err != nil {
		conn.Close()
		return err
	}

	c.conn = conn
	return nil
}
2015-06-27 17:52:01 -07:00
func (c *ProtocolClient) cleanup() {
if c.closeInvitationsOnFinish {
close(c.Invitations)
c.Invitations = make(chan protocol.SessionInvitation)
}
if debug {
l.Debugln(c, "cleaning up")
}
2015-07-22 14:34:05 -07:00
c.mut.Lock()
c.connected = false
c.mut.Unlock()
2015-07-20 02:56:10 -07:00
c.conn.Close()
2015-06-27 17:52:01 -07:00
}
func (c *ProtocolClient) join() error {
2015-07-20 02:56:10 -07:00
if err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}); err != nil {
2015-06-27 17:52:01 -07:00
return err
}
message, err := protocol.ReadMessage(c.conn)
if err != nil {
return err
}
switch msg := message.(type) {
case protocol.Response:
if msg.Code != 0 {
return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message)
}
2015-07-20 02:56:10 -07:00
2015-06-27 17:52:01 -07:00
default:
return fmt.Errorf("protocol error: expecting response got %v", msg)
}
return nil
}
// performHandshakeAndValidation completes the TLS handshake on conn and checks
// that protocol.ProtocolName was mutually negotiated.
//
// When uri carries an "id" query parameter, it additionally requires that the
// relay present exactly one certificate whose device ID equals that parameter.
// Without an "id" parameter, no device ID check is performed.
func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error {
	if err := conn.Handshake(); err != nil {
		return err
	}

	state := conn.ConnectionState()
	if !state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol != protocol.ProtocolName {
		return fmt.Errorf("protocol negotiation error")
	}

	pinned := uri.Query().Get("id")
	if pinned == "" {
		return nil
	}

	expected, err := syncthingprotocol.DeviceIDFromString(pinned)
	if err != nil {
		return fmt.Errorf("relay address contains invalid verification id: %s", err)
	}

	peers := state.PeerCertificates
	if count := len(peers); count != 1 {
		return fmt.Errorf("unexpected certificate count: %d", count)
	}

	if actual := syncthingprotocol.NewDeviceID(peers[0].Raw); actual != expected {
		return fmt.Errorf("relay id does not match. Expected %v got %v", expected, actual)
	}
	return nil
}
2015-07-20 02:56:10 -07:00
func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
for {
msg, err := protocol.ReadMessage(conn)
if err != nil {
errors <- err
return
}
messages <- msg
}
}