From f8bedc55e5da8ca23cbe249aa13ba3a228577b68 Mon Sep 17 00:00:00 2001 From: Audrius Butkevicius Date: Sun, 28 Jun 2015 01:52:01 +0100 Subject: [PATCH] Progress --- cmd/relaysrv/README.md | 6 + cmd/relaysrv/client/client.go | 249 +++++++++++++++++++++++++++ cmd/relaysrv/client/debug.go | 15 ++ cmd/relaysrv/client/methods.go | 113 ++++++++++++ cmd/relaysrv/main.go | 39 +++-- cmd/relaysrv/protocol/packets.go | 42 ++--- cmd/relaysrv/protocol/packets_xdr.go | 216 +++++++++++++++++++---- cmd/relaysrv/protocol/protocol.go | 114 ++++++++++++ cmd/relaysrv/protocol_listener.go | 233 +++++++++++-------------- cmd/relaysrv/session.go | 179 ++++++++++--------- cmd/relaysrv/session_listener.go | 58 +++++-- cmd/relaysrv/testutil/main.go | 142 +++++++++++++++ cmd/relaysrv/utils.go | 27 +-- 13 files changed, 1114 insertions(+), 319 deletions(-) create mode 100644 cmd/relaysrv/README.md create mode 100644 cmd/relaysrv/client/client.go create mode 100644 cmd/relaysrv/client/debug.go create mode 100644 cmd/relaysrv/client/methods.go create mode 100644 cmd/relaysrv/protocol/protocol.go create mode 100644 cmd/relaysrv/testutil/main.go diff --git a/cmd/relaysrv/README.md b/cmd/relaysrv/README.md new file mode 100644 index 000000000..e88929280 --- /dev/null +++ b/cmd/relaysrv/README.md @@ -0,0 +1,6 @@ +relaysrv +======== + +This is the relay server for the `syncthing` project. + +`go get github.com/syncthing/relaysrv` diff --git a/cmd/relaysrv/client/client.go b/cmd/relaysrv/client/client.go new file mode 100644 index 000000000..b48320fd2 --- /dev/null +++ b/cmd/relaysrv/client/client.go @@ -0,0 +1,249 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package client + +import ( + "crypto/tls" + "fmt" + "log" + "net" + "net/url" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/protocol" +) + +func NewProtocolClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation) ProtocolClient { + closeInvitationsOnFinish := false + if invitations == nil { + closeInvitationsOnFinish = true + invitations = make(chan protocol.SessionInvitation) + } + return ProtocolClient{ + URI: uri, + Invitations: invitations, + + closeInvitationsOnFinish: closeInvitationsOnFinish, + + config: configForCerts(certs), + + timeout: time.Minute * 2, + + stop: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +type ProtocolClient struct { + URI *url.URL + Invitations chan protocol.SessionInvitation + + closeInvitationsOnFinish bool + + config *tls.Config + + timeout time.Duration + + stop chan struct{} + stopped chan struct{} + + conn *tls.Conn +} + +func (c *ProtocolClient) connect() error { + conn, err := tls.Dial("tcp", c.URI.Host, c.config) + if err != nil { + return err + } + + conn.SetDeadline(time.Now().Add(10 * time.Second)) + + if err := performHandshakeAndValidation(conn, c.URI); err != nil { + return err + } + + c.conn = conn + return nil +} + +func (c *ProtocolClient) Serve() { + if err := c.connect(); err != nil { + panic(err) + } + + if debug { + l.Debugln(c, "connected", c.conn.RemoteAddr()) + } + + if err := c.join(); err != nil { + c.conn.Close() + panic(err) + } + + c.conn.SetDeadline(time.Time{}) + + if debug { + l.Debugln(c, "joined", c.conn.RemoteAddr(), "via", c.conn.LocalAddr()) + } + + c.stop = make(chan struct{}) + c.stopped = make(chan struct{}) + + defer c.cleanup() + + messages := make(chan interface{}) + errors := make(chan error, 1) + + go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } + }(c.conn, messages, errors) + + timeout := time.NewTimer(c.timeout) + for { + select { + case message := <-messages: + timeout.Reset(c.timeout) + if debug { + log.Printf("%s received message %T", c, message) + } + switch msg := message.(type) { + case protocol.Ping: + if err := protocol.WriteMessage(c.conn, protocol.Pong{}); err != nil { + panic(err) + } + if debug { + l.Debugln(c, "sent pong") + } + case protocol.SessionInvitation: + ip := net.IP(msg.Address) + if len(ip) == 0 || ip.IsUnspecified() { + msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:] + } + c.Invitations <- msg + default: + panic(fmt.Errorf("protocol error: unexpected message %v", msg)) + } + case <-c.stop: + if debug { + l.Debugln(c, "stopping") + } + break + case err := <-errors: + panic(err) + case <-timeout.C: + if debug { + l.Debugln(c, "timed out") + } + return + } + } + + c.stopped <- struct{}{} +} + +func (c *ProtocolClient) Stop() { + if c.stop == nil { + return + } + + c.stop <- struct{}{} + <-c.stopped +} + +func (c *ProtocolClient) String() string { + return fmt.Sprintf("ProtocolClient@%p", c) +} + +func (c *ProtocolClient) cleanup() { + if c.closeInvitationsOnFinish { + close(c.Invitations) + c.Invitations = make(chan protocol.SessionInvitation) + } + + if debug { + l.Debugln(c, "cleaning up") + } + + if c.stop != nil { + close(c.stop) + c.stop = nil + } + + if c.stopped != nil { + close(c.stopped) + c.stopped = nil + } + + if c.conn != nil { + c.conn.Close() + } +} + +func (c *ProtocolClient) join() error { + err := protocol.WriteMessage(c.conn, protocol.JoinRelayRequest{}) + if err != nil { + return err + } + + message, err := protocol.ReadMessage(c.conn) + if err != nil { + return err + } + + switch msg := message.(type) { + case protocol.Response: + if msg.Code != 0 { + return fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) + } + default: + return fmt.Errorf("protocol error: expecting response got %v", msg) + } + + return nil +} + +func performHandshakeAndValidation(conn *tls.Conn, uri *url.URL) error { + err := conn.Handshake() + if err != nil { + conn.Close() + return err + } + + cs := conn.ConnectionState() + if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != protocol.ProtocolName { + conn.Close() + return fmt.Errorf("protocol negotiation error") + } + + q := uri.Query() + relayIDs := q.Get("id") + if relayIDs != "" { + relayID, err := syncthingprotocol.DeviceIDFromString(relayIDs) + if err != nil { + conn.Close() + return fmt.Errorf("relay address contains invalid verification id: %s", err) + } + + certs := cs.PeerCertificates + if cl := len(certs); cl != 1 { + conn.Close() + return fmt.Errorf("unexpected certificate count: %d", cl) + } + + remoteID := syncthingprotocol.NewDeviceID(certs[0].Raw) + if remoteID != relayID { + conn.Close() + return fmt.Errorf("relay id does not match. Expected %v got %v", relayID, remoteID) + } + } + + return nil +} diff --git a/cmd/relaysrv/client/debug.go b/cmd/relaysrv/client/debug.go new file mode 100644 index 000000000..4a3608dec --- /dev/null +++ b/cmd/relaysrv/client/debug.go @@ -0,0 +1,15 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package client + +import ( + "os" + "strings" + + "github.com/calmh/logger" +) + +var ( + debug = strings.Contains(os.Getenv("STTRACE"), "relay") || os.Getenv("STTRACE") == "all" + l = logger.DefaultLogger +) diff --git a/cmd/relaysrv/client/methods.go b/cmd/relaysrv/client/methods.go new file mode 100644 index 000000000..1d457e294 --- /dev/null +++ b/cmd/relaysrv/client/methods.go @@ -0,0 +1,113 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package client + +import ( + "crypto/tls" + "fmt" + "net" + "net/url" + "strconv" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/protocol" +) + +func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate) (protocol.SessionInvitation, error) { + conn, err := tls.Dial("tcp", uri.Host, configForCerts(certs)) + conn.SetDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + return protocol.SessionInvitation{}, err + } + + if err := performHandshakeAndValidation(conn, uri); err != nil { + return protocol.SessionInvitation{}, err + } + + defer conn.Close() + + request := protocol.ConnectRequest{ + ID: id[:], + } + + if err := protocol.WriteMessage(conn, request); err != nil { + return protocol.SessionInvitation{}, err + } + + message, err := protocol.ReadMessage(conn) + if err != nil { + return protocol.SessionInvitation{}, err + } + + switch msg := message.(type) { + case protocol.Response: + return protocol.SessionInvitation{}, fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) + case protocol.SessionInvitation: + if debug { + l.Debugln("Received invitation via", conn.LocalAddr()) + } + ip := net.IP(msg.Address) + if len(ip) == 0 || ip.IsUnspecified() { + msg.Address = conn.RemoteAddr().(*net.TCPAddr).IP[:] + } + return msg, nil + default: + return protocol.SessionInvitation{}, fmt.Errorf("protocol error: unexpected message %v", msg) + } +} + +func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) { + addr := net.JoinHostPort(net.IP(invitation.Address).String(), strconv.Itoa(int(invitation.Port))) + + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + + request := protocol.JoinSessionRequest{ + Key: invitation.Key, + } + + conn.SetDeadline(time.Now().Add(10 * time.Second)) + err = protocol.WriteMessage(conn, request) + if err != nil { + return nil, err + } + + message, err := protocol.ReadMessage(conn) + if err != nil { + return nil, err + } + + conn.SetDeadline(time.Time{}) + + switch msg := message.(type) { + case protocol.Response: + if msg.Code != 0 { + return nil, fmt.Errorf("Incorrect response code %d: %s", msg.Code, msg.Message) + } + return conn, nil + default: + return nil, fmt.Errorf("protocol error: expecting response got %v", msg) + } +} + +func configForCerts(certs []tls.Certificate) *tls.Config { + return &tls.Config{ + Certificates: certs, + NextProtos: []string{protocol.ProtocolName}, + ClientAuth: tls.RequestClientCert, + SessionTicketsDisabled: true, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, + } +} diff --git a/cmd/relaysrv/main.go b/cmd/relaysrv/main.go index 3c4d533ed..5ca060689 100644 --- a/cmd/relaysrv/main.go +++ b/cmd/relaysrv/main.go @@ -6,13 +6,13 @@ import ( "crypto/tls" "flag" "log" - "os" + "net" "path/filepath" - "sync" "time" - syncthingprotocol "github.com/syncthing/protocol" "github.com/syncthing/relaysrv/protocol" + + syncthingprotocol "github.com/syncthing/protocol" ) var ( @@ -26,26 +26,11 @@ var ( networkTimeout time.Duration pingInterval time.Duration messageTimeout time.Duration - - pingMessage message - - mut = sync.RWMutex{} - outbox = make(map[syncthingprotocol.DeviceID]chan message) ) func main() { var dir, extAddress string - pingPayload := protocol.Ping{}.MustMarshalXDR() - pingMessage = message{ - header: protocol.Header{ - Magic: protocol.Magic, - MessageType: protocol.MessageTypePing, - MessageLength: int32(len(pingPayload)), - }, - payload: pingPayload, - } - flag.StringVar(&listenProtocol, "protocol-listen", ":22067", "Protocol listen address") flag.StringVar(&listenSession, "session-listen", ":22068", "Session listen address") flag.StringVar(&extAddress, "external-address", "", "External address to advertise, defaults no IP and session-listen port, causing clients to use the remote IP from the protocol connection") @@ -54,7 +39,20 @@ func main() { flag.DurationVar(&pingInterval, "ping-interval", time.Minute, "How often pings are sent") flag.DurationVar(&messageTimeout, "message-timeout", time.Minute, "Maximum amount of time we wait for relevant messages to arrive") + if extAddress == "" { + extAddress = listenSession + } + + addr, err := net.ResolveTCPAddr("tcp", extAddress) + if err != nil { + log.Fatal(err) + } + + sessionAddress = addr.IP[:] + sessionPort = uint16(addr.Port) + flag.BoolVar(&debug, "debug", false, "Enable debug output") + flag.Parse() certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") @@ -80,7 +78,10 @@ func main() { }, } - log.SetOutput(os.Stdout) + id := syncthingprotocol.NewDeviceID(cert.Certificate[0]) + if debug { + log.Println("ID:", id) + } go sessionListener(listenSession) diff --git a/cmd/relaysrv/protocol/packets.go b/cmd/relaysrv/protocol/packets.go index 4675d1cf4..658bc536d 100644 --- a/cmd/relaysrv/protocol/packets.go +++ b/cmd/relaysrv/protocol/packets.go @@ -5,39 +5,41 @@ package protocol -import ( - "unsafe" -) - const ( - Magic = 0x9E79BC40 - HeaderSize = unsafe.Sizeof(&Header{}) - ProtocolName = "bep-relay" + messageTypePing int32 = iota + messageTypePong + messageTypeJoinRelayRequest + messageTypeJoinSessionRequest + messageTypeResponse + messageTypeConnectRequest + messageTypeSessionInvitation ) -const ( - MessageTypePing int32 = iota - MessageTypePong - MessageTypeJoinRequest - MessageTypeConnectRequest - MessageTypeSessionInvitation -) - -type Header struct { - Magic uint32 - MessageType int32 - MessageLength int32 +type header struct { + magic uint32 + messageType int32 + messageLength int32 } type Ping struct{} type Pong struct{} -type JoinRequest struct{} +type JoinRelayRequest struct{} + +type JoinSessionRequest struct { + Key []byte // max:32 +} + +type Response struct { + Code int32 + Message string +} type ConnectRequest struct { ID []byte // max:32 } type SessionInvitation struct { + From []byte // max:32 Key []byte // max:32 Address []byte // max:32 Port uint16 diff --git a/cmd/relaysrv/protocol/packets_xdr.go b/cmd/relaysrv/protocol/packets_xdr.go index ca547e007..f18e18c18 100644 --- a/cmd/relaysrv/protocol/packets_xdr.go +++ b/cmd/relaysrv/protocol/packets_xdr.go @@ -13,37 +13,37 @@ import ( /* -Header Structure: +header Structure: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Magic | +| magic | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Message Type | +| message Type | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Message Length | +| message Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -struct Header { - unsigned int Magic; - int MessageType; - int MessageLength; +struct header { + unsigned int magic; + int messageType; + int messageLength; } */ -func (o Header) EncodeXDR(w io.Writer) (int, error) { +func (o header) EncodeXDR(w io.Writer) (int, error) { var xw = xdr.NewWriter(w) return o.EncodeXDRInto(xw) } -func (o Header) MarshalXDR() ([]byte, error) { +func (o header) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o Header) MustMarshalXDR() []byte { +func (o header) MustMarshalXDR() []byte { bs, err := o.MarshalXDR() if err != nil { panic(err) @@ -51,35 +51,35 @@ func (o Header) MustMarshalXDR() []byte { return bs } -func (o Header) AppendXDR(bs []byte) ([]byte, error) { +func (o header) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) _, err := o.EncodeXDRInto(xw) return []byte(aw), err } -func (o Header) EncodeXDRInto(xw *xdr.Writer) (int, error) { - xw.WriteUint32(o.Magic) - xw.WriteUint32(uint32(o.MessageType)) - xw.WriteUint32(uint32(o.MessageLength)) +func (o header) EncodeXDRInto(xw *xdr.Writer) (int, error) { + xw.WriteUint32(o.magic) + xw.WriteUint32(uint32(o.messageType)) + xw.WriteUint32(uint32(o.messageLength)) return xw.Tot(), xw.Error() } -func (o *Header) DecodeXDR(r io.Reader) error { +func (o *header) DecodeXDR(r io.Reader) error { xr := xdr.NewReader(r) return o.DecodeXDRFrom(xr) } -func (o *Header) UnmarshalXDR(bs []byte) error { +func (o *header) UnmarshalXDR(bs []byte) error { var br = bytes.NewReader(bs) var xr = xdr.NewReader(br) return o.DecodeXDRFrom(xr) } -func (o *Header) DecodeXDRFrom(xr *xdr.Reader) error { - o.Magic = xr.ReadUint32() - o.MessageType = int32(xr.ReadUint32()) - o.MessageLength = int32(xr.ReadUint32()) +func (o *header) DecodeXDRFrom(xr *xdr.Reader) error { + o.magic = xr.ReadUint32() + o.messageType = int32(xr.ReadUint32()) + o.messageLength = int32(xr.ReadUint32()) return xr.Error() } @@ -199,28 +199,28 @@ func (o *Pong) DecodeXDRFrom(xr *xdr.Reader) error { /* -JoinRequest Structure: +JoinRelayRequest Structure: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -struct JoinRequest { +struct JoinRelayRequest { } */ -func (o JoinRequest) EncodeXDR(w io.Writer) (int, error) { +func (o JoinRelayRequest) EncodeXDR(w io.Writer) (int, error) { var xw = xdr.NewWriter(w) return o.EncodeXDRInto(xw) } -func (o JoinRequest) MarshalXDR() ([]byte, error) { +func (o JoinRelayRequest) MarshalXDR() ([]byte, error) { return o.AppendXDR(make([]byte, 0, 128)) } -func (o JoinRequest) MustMarshalXDR() []byte { +func (o JoinRelayRequest) MustMarshalXDR() []byte { bs, err := o.MarshalXDR() if err != nil { panic(err) @@ -228,29 +228,169 @@ func (o JoinRequest) MustMarshalXDR() []byte { return bs } -func (o JoinRequest) AppendXDR(bs []byte) ([]byte, error) { +func (o JoinRelayRequest) AppendXDR(bs []byte) ([]byte, error) { var aw = xdr.AppendWriter(bs) var xw = xdr.NewWriter(&aw) _, err := o.EncodeXDRInto(xw) return []byte(aw), err } -func (o JoinRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { +func (o JoinRelayRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { return xw.Tot(), xw.Error() } -func (o *JoinRequest) DecodeXDR(r io.Reader) error { +func (o *JoinRelayRequest) DecodeXDR(r io.Reader) error { xr := xdr.NewReader(r) return o.DecodeXDRFrom(xr) } -func (o *JoinRequest) UnmarshalXDR(bs []byte) error { +func (o *JoinRelayRequest) UnmarshalXDR(bs []byte) error { var br = bytes.NewReader(bs) var xr = xdr.NewReader(br) return o.DecodeXDRFrom(xr) } -func (o *JoinRequest) DecodeXDRFrom(xr *xdr.Reader) error { +func (o *JoinRelayRequest) DecodeXDRFrom(xr *xdr.Reader) error { + return xr.Error() +} + +/* + +JoinSessionRequest Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Key | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Key (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct JoinSessionRequest { + opaque Key<32>; +} + +*/ + +func (o JoinSessionRequest) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o JoinSessionRequest) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o JoinSessionRequest) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o JoinSessionRequest) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o JoinSessionRequest) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.Key); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32) + } + xw.WriteBytes(o.Key) + return xw.Tot(), xw.Error() +} + +func (o *JoinSessionRequest) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinSessionRequest) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *JoinSessionRequest) DecodeXDRFrom(xr *xdr.Reader) error { + o.Key = xr.ReadBytesMax(32) + return xr.Error() +} + +/* + +Response Structure: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Code | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of Message | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ Message (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + +struct Response { + int Code; + string Message<>; +} + +*/ + +func (o Response) EncodeXDR(w io.Writer) (int, error) { + var xw = xdr.NewWriter(w) + return o.EncodeXDRInto(xw) +} + +func (o Response) MarshalXDR() ([]byte, error) { + return o.AppendXDR(make([]byte, 0, 128)) +} + +func (o Response) MustMarshalXDR() []byte { + bs, err := o.MarshalXDR() + if err != nil { + panic(err) + } + return bs +} + +func (o Response) AppendXDR(bs []byte) ([]byte, error) { + var aw = xdr.AppendWriter(bs) + var xw = xdr.NewWriter(&aw) + _, err := o.EncodeXDRInto(xw) + return []byte(aw), err +} + +func (o Response) EncodeXDRInto(xw *xdr.Writer) (int, error) { + xw.WriteUint32(uint32(o.Code)) + xw.WriteString(o.Message) + return xw.Tot(), xw.Error() +} + +func (o *Response) DecodeXDR(r io.Reader) error { + xr := xdr.NewReader(r) + return o.DecodeXDRFrom(xr) +} + +func (o *Response) UnmarshalXDR(bs []byte) error { + var br = bytes.NewReader(bs) + var xr = xdr.NewReader(br) + return o.DecodeXDRFrom(xr) +} + +func (o *Response) DecodeXDRFrom(xr *xdr.Reader) error { + o.Code = int32(xr.ReadUint32()) + o.Message = xr.ReadString() return xr.Error() } @@ -330,6 +470,12 @@ SessionInvitation Structure: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Length of From | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ From (variable length) \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Length of Key | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / / @@ -349,6 +495,7 @@ SessionInvitation Structure: struct SessionInvitation { + opaque From<32>; opaque Key<32>; opaque Address<32>; unsigned int Port; @@ -382,6 +529,10 @@ func (o SessionInvitation) AppendXDR(bs []byte) ([]byte, error) { } func (o SessionInvitation) EncodeXDRInto(xw *xdr.Writer) (int, error) { + if l := len(o.From); l > 32 { + return xw.Tot(), xdr.ElementSizeExceeded("From", l, 32) + } + xw.WriteBytes(o.From) if l := len(o.Key); l > 32 { return xw.Tot(), xdr.ElementSizeExceeded("Key", l, 32) } @@ -407,6 +558,7 @@ func (o *SessionInvitation) UnmarshalXDR(bs []byte) error { } func (o *SessionInvitation) DecodeXDRFrom(xr *xdr.Reader) error { + o.From = xr.ReadBytesMax(32) o.Key = xr.ReadBytesMax(32) o.Address = xr.ReadBytesMax(32) o.Port = xr.ReadUint16() diff --git a/cmd/relaysrv/protocol/protocol.go b/cmd/relaysrv/protocol/protocol.go new file mode 100644 index 000000000..57a967ac8 --- /dev/null +++ b/cmd/relaysrv/protocol/protocol.go @@ -0,0 +1,114 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package protocol + +import ( + "fmt" + "io" +) + +const ( + magic = 0x9E79BC40 + ProtocolName = "bep-relay" +) + +var ( + ResponseSuccess = Response{0, "success"} + ResponseNotFound = Response{1, "not found"} + ResponseAlreadyConnected = Response{2, "already connected"} + ResponseInternalError = Response{99, "internal error"} + ResponseUnexpectedMessage = Response{100, "unexpected message"} +) + +func WriteMessage(w io.Writer, message interface{}) error { + header := header{ + magic: magic, + } + + var payload []byte + var err error + + switch msg := message.(type) { + case Ping: + payload, err = msg.MarshalXDR() + header.messageType = messageTypePing + case Pong: + payload, err = msg.MarshalXDR() + header.messageType = messageTypePong + case JoinRelayRequest: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeJoinRelayRequest + case JoinSessionRequest: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeJoinSessionRequest + case Response: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeResponse + case ConnectRequest: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeConnectRequest + case SessionInvitation: + payload, err = msg.MarshalXDR() + header.messageType = messageTypeSessionInvitation + default: + err = fmt.Errorf("Unknown message type") + } + + if err != nil { + return err + } + + header.messageLength = int32(len(payload)) + + headerpayload, err := header.MarshalXDR() + if err != nil { + return err + } + + _, err = w.Write(append(headerpayload, payload...)) + return err +} + +func ReadMessage(r io.Reader) (interface{}, error) { + var header header + if err := header.DecodeXDR(r); err != nil { + return nil, err + } + + if header.magic != magic { + return nil, fmt.Errorf("magic mismatch") + } + + switch header.messageType { + case messageTypePing: + var msg Ping + err := msg.DecodeXDR(r) + return msg, err + case messageTypePong: + var msg Pong + err := msg.DecodeXDR(r) + return msg, err + case messageTypeJoinRelayRequest: + var msg JoinRelayRequest + err := msg.DecodeXDR(r) + return msg, err + case messageTypeJoinSessionRequest: + var msg JoinSessionRequest + err := msg.DecodeXDR(r) + return msg, err + case messageTypeResponse: + var msg Response + err := msg.DecodeXDR(r) + return msg, err + case messageTypeConnectRequest: + var msg ConnectRequest + err := msg.DecodeXDR(r) + return msg, err + case messageTypeSessionInvitation: + var msg SessionInvitation + err := msg.DecodeXDR(r) + return msg, err + } + + return nil, fmt.Errorf("Unknown message type") +} diff --git a/cmd/relaysrv/protocol_listener.go b/cmd/relaysrv/protocol_listener.go index b6d89b226..1e18b156e 100644 --- a/cmd/relaysrv/protocol_listener.go +++ b/cmd/relaysrv/protocol_listener.go @@ -4,9 +4,9 @@ package main import ( "crypto/tls" - "io" "log" "net" + "sync" "time" syncthingprotocol "github.com/syncthing/protocol" @@ -14,10 +14,10 @@ import ( "github.com/syncthing/relaysrv/protocol" ) -type message struct { - header protocol.Header - payload []byte -} +var ( + outboxesMut = sync.RWMutex{} + outboxes = make(map[syncthingprotocol.DeviceID]chan interface{}) +) func protocolListener(addr string, config *tls.Config) { listener, err := net.Listen("tcp", addr) @@ -27,6 +27,7 @@ func protocolListener(addr string, config *tls.Config) { for { conn, err := listener.Accept() + setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -43,15 +44,12 @@ func protocolListener(addr string, config *tls.Config) { } func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { - err := setTCPOptions(tcpConn) - if err != nil && debug { - log.Println("Failed to set TCP options on protocol connection", tcpConn.RemoteAddr(), err) - } - conn := tls.Server(tcpConn, config) - err = conn.Handshake() + err := conn.Handshake() if err != nil { - log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err) + if debug { + log.Println("Protocol connection TLS handshake:", conn.RemoteAddr(), err) + } conn.Close() return } @@ -63,168 +61,147 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config) { certs := state.PeerCertificates if len(certs) != 1 { - log.Println("Certificate list error") + if debug { + log.Println("Certificate list error") + } conn.Close() return } - deviceId := syncthingprotocol.NewDeviceID(certs[0].Raw) + id := syncthingprotocol.NewDeviceID(certs[0].Raw) - mut.RLock() - _, ok := outbox[deviceId] - mut.RUnlock() - if ok { - log.Println("Already have a peer with the same ID", deviceId, conn.RemoteAddr()) - conn.Close() - return - } + messages := make(chan interface{}) + errors := make(chan error, 1) + outbox := make(chan interface{}) - errorChannel := make(chan error) - messageChannel := make(chan message) - outboxChannel := make(chan message) - - go readerLoop(conn, messageChannel, errorChannel) + go func(conn net.Conn, message chan<- interface{}, errors chan<- error) { + for { + msg, err := protocol.ReadMessage(conn) + if err != nil { + errors <- err + return + } + messages <- msg + } + }(conn, messages, errors) pingTicker := time.NewTicker(pingInterval) - timeoutTicker := time.NewTimer(messageTimeout * 2) + timeoutTicker := time.NewTimer(networkTimeout) joined := false for { select { - case msg := <-messageChannel: - switch msg.header.MessageType { - case protocol.MessageTypeJoinRequest: - mut.Lock() - outbox[deviceId] = outboxChannel - mut.Unlock() - joined = true - case protocol.MessageTypeConnectRequest: - // We will disconnect after this message, no matter what, - // because, we've either sent out an invitation, or we don't - // have the peer available. - var fmsg protocol.ConnectRequest - err := fmsg.UnmarshalXDR(msg.payload) - if err != nil { - log.Println(err) + case message := <-messages: + timeoutTicker.Reset(networkTimeout) + if debug { + log.Printf("Message %T from %s", message, id) + } + switch msg := message.(type) { + case protocol.JoinRelayRequest: + outboxesMut.RLock() + _, ok := outboxes[id] + outboxesMut.RUnlock() + if ok { + protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) + if debug { + log.Println("Already have a peer with the same ID", id, conn.RemoteAddr()) + } conn.Close() continue } - requestedPeer := syncthingprotocol.DeviceIDFromBytes(fmsg.ID) - mut.RLock() - peerOutbox, ok := outbox[requestedPeer] - mut.RUnlock() + outboxesMut.Lock() + outboxes[id] = outbox + outboxesMut.Unlock() + joined = true + + protocol.WriteMessage(conn, protocol.ResponseSuccess) + case protocol.ConnectRequest: + requestedPeer := syncthingprotocol.DeviceIDFromBytes(msg.ID) + outboxesMut.RLock() + peerOutbox, ok := outboxes[requestedPeer] + outboxesMut.RUnlock() if !ok { if debug { - log.Println("Do not have", requestedPeer) + log.Println(id, "is looking", requestedPeer, "which does not exist") } + protocol.WriteMessage(conn, protocol.ResponseNotFound) conn.Close() continue } ses := newSession() - smsg, err := ses.GetServerInvitationMessage() - if err != nil { - log.Println("Error getting server invitation", requestedPeer) - conn.Close() - continue - } - cmsg, err := ses.GetClientInvitationMessage() - if err != nil { - log.Println("Error getting client invitation", requestedPeer) - conn.Close() - continue - } - go ses.Serve() - if err := sendMessage(cmsg, conn); err != nil { - log.Println("Failed to send invitation message", err) - } else { - peerOutbox <- smsg + clientInvitation := ses.GetClientInvitationMessage(requestedPeer) + serverInvitation := ses.GetServerInvitationMessage(id) + + if err := protocol.WriteMessage(conn, clientInvitation); err != nil { if debug { - log.Println("Sent invitation from", deviceId, "to", requestedPeer) + log.Printf("Error sending invitation from %s to client: %s", id, err) } + conn.Close() + continue + } + + peerOutbox <- serverInvitation + + if debug { + log.Println("Sent invitation from", id, "to", requestedPeer) } conn.Close() - case protocol.MessageTypePong: - timeoutTicker.Reset(messageTimeout) + case protocol.Pong: + default: + if debug { + log.Printf("Unknown message %s: %T", id, message) + } + protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) + conn.Close() } - case err := <-errorChannel: - log.Println("Closing connection:", err) + case err := <-errors: + if debug { + log.Printf("Closing connection %s: %s", id, err) + } + // Potentially closing a second time. + close(outbox) + conn.Close() + outboxesMut.Lock() + delete(outboxes, id) + outboxesMut.Unlock() return case <-pingTicker.C: if !joined { - log.Println(deviceId, "didn't join within", messageTimeout) + if debug { + log.Println(id, "didn't join within", pingInterval) + } conn.Close() continue } - if err := sendMessage(pingMessage, conn); err != nil { - log.Println(err) + if err := protocol.WriteMessage(conn, protocol.Ping{}); err != nil { + if debug { + log.Println(id, err) + } conn.Close() - continue } case <-timeoutTicker.C: - // We should receive a error, which will cause us to quit the - // loop. - conn.Close() - case msg := <-outboxChannel: + // We should receive a error from the reader loop, which will cause + // us to quit this loop. if debug { - log.Println("Sending message to", deviceId, msg) + log.Printf("%s timed out", id) } - if err := sendMessage(msg, conn); err == nil { - log.Println(err) + conn.Close() + case msg := <-outbox: + if debug { + log.Printf("Sending message %T to %s", msg, id) + } + if err := protocol.WriteMessage(conn, msg); err != nil { + if debug { + log.Println(id, err) + } conn.Close() - continue } } } } - -func readerLoop(conn *tls.Conn, messages chan<- message, errors chan<- error) { - header := make([]byte, protocol.HeaderSize) - data := make([]byte, 0, 0) - for { - _, err := io.ReadFull(conn, header) - if err != nil { - errors <- err - conn.Close() - return - } - - var hdr protocol.Header - err = hdr.UnmarshalXDR(header) - if err != nil { - conn.Close() - return - } - - if hdr.Magic != protocol.Magic { - conn.Close() - return - } - - if hdr.MessageLength > int32(cap(data)) { - data = make([]byte, 0, hdr.MessageLength) - } else { - data = data[:hdr.MessageLength] - } - - _, err = io.ReadFull(conn, data) - if err != nil { - errors <- err - conn.Close() - return - } - - msg := message{ - header: hdr, - payload: make([]byte, hdr.MessageLength), - } - copy(msg.payload, data[:hdr.MessageLength]) - - messages <- msg - } -} diff --git a/cmd/relaysrv/session.go b/cmd/relaysrv/session.go index 3466bd535..c5a091952 100644 --- a/cmd/relaysrv/session.go +++ b/cmd/relaysrv/session.go @@ -4,23 +4,27 @@ package main import ( "crypto/rand" + "encoding/hex" + "fmt" + "log" "net" "sync" "time" "github.com/syncthing/relaysrv/protocol" + + syncthingprotocol "github.com/syncthing/protocol" ) var ( - sessionmut = sync.Mutex{} + sessionMut = sync.Mutex{} sessions = make(map[string]*session, 0) ) type session struct { - serverkey string - clientkey string + serverkey []byte + clientkey []byte - mut sync.RWMutex conns chan net.Conn } @@ -37,16 +41,27 @@ func newSession() *session { return nil } - return &session{ - serverkey: string(serverkey), - clientkey: string(clientkey), + ses := &session{ + serverkey: serverkey, + clientkey: clientkey, conns: make(chan net.Conn), } + + if debug { + log.Println("New session", ses) + } + + sessionMut.Lock() + sessions[string(ses.serverkey)] = ses + sessions[string(ses.clientkey)] = ses + sessionMut.Unlock() + + return ses } func findSession(key string) *session { - sessionmut.Lock() - defer sessionmut.Unlock() + sessionMut.Lock() + defer sessionMut.Unlock() lob, ok := sessions[key] if !ok { return nil @@ -56,118 +71,128 @@ func findSession(key string) *session { return lob } -func (l *session) AddConnection(conn net.Conn) { +func (s *session) AddConnection(conn net.Conn) bool { + if debug { + log.Println("New connection for", s, "from", conn.RemoteAddr()) + } + select { - case l.conns <- conn: + case s.conns <- conn: + return true default: } + return false } -func (l *session) Serve() { - +func (s *session) Serve() { timedout := time.After(messageTimeout) - sessionmut.Lock() - sessions[l.serverkey] = l - sessions[l.clientkey] = l - sessionmut.Unlock() + if debug { + log.Println("Session", s, "serving") + } conns := make([]net.Conn, 0, 2) for { select { - case conn := <-l.conns: + case conn := <-s.conns: conns = append(conns, conn) if len(conns) < 2 { continue } - close(l.conns) + close(s.conns) + + if debug { + log.Println("Session", s, "starting between", conns[0].RemoteAddr(), conns[1].RemoteAddr()) + } wg := sync.WaitGroup{} - wg.Add(2) - go proxy(conns[0], conns[1], wg) - go proxy(conns[1], conns[0], wg) + errors := make(chan error, 2) + + go func() { + errors <- proxy(conns[0], conns[1]) + wg.Done() + }() + + go func() { + errors <- proxy(conns[1], conns[0]) + wg.Done() + }() wg.Wait() - break - case <-timedout: - sessionmut.Lock() - delete(sessions, l.serverkey) - delete(sessions, l.clientkey) - sessionmut.Unlock() - - for _, conn := range conns { - conn.Close() + if debug { + log.Println("Session", s, "ended, outcomes:", <-errors, <-errors) } - - break + goto done + case <-timedout: + if debug { + log.Println("Session", s, "timed out") + } + goto done } } +done: + sessionMut.Lock() + delete(sessions, string(s.serverkey)) + delete(sessions, string(s.clientkey)) + sessionMut.Unlock() + + for _, conn := range conns { + conn.Close() + } + + if debug { + log.Println("Session", s, "stopping") + } } -func (l *session) GetClientInvitationMessage() (message, error) { - invitation := protocol.SessionInvitation{ - Key: []byte(l.clientkey), - Address: nil, - Port: 123, +func (s *session) GetClientInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { + return protocol.SessionInvitation{ + From: from[:], + Key: []byte(s.clientkey), + Address: sessionAddress, + Port: sessionPort, ServerSocket: false, } - data, err := invitation.MarshalXDR() - if err != nil { - return message{}, err - } - - return message{ - header: protocol.Header{ - Magic: protocol.Magic, - MessageType: protocol.MessageTypeSessionInvitation, - MessageLength: int32(len(data)), - }, - payload: data, - }, nil } -func (l *session) GetServerInvitationMessage() (message, error) { - invitation := protocol.SessionInvitation{ - Key: []byte(l.serverkey), - Address: nil, - Port: 123, +func (s *session) GetServerInvitationMessage(from syncthingprotocol.DeviceID) protocol.SessionInvitation { + return protocol.SessionInvitation{ + From: from[:], + Key: []byte(s.serverkey), + Address: sessionAddress, + Port: sessionPort, ServerSocket: true, } - data, err := invitation.MarshalXDR() - if err != nil { - return message{}, err - } - - return message{ - header: protocol.Header{ - Magic: protocol.Magic, - MessageType: protocol.MessageTypeSessionInvitation, - MessageLength: int32(len(data)), - }, - payload: data, - }, nil } -func proxy(c1, c2 net.Conn, wg sync.WaitGroup) { +func proxy(c1, c2 net.Conn) error { + if debug { + log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr()) + } + buf := make([]byte, 1024) for { - buf := make([]byte, 1024) c1.SetReadDeadline(time.Now().Add(networkTimeout)) - n, err := c1.Read(buf) + n, err := c1.Read(buf[0:]) if err != nil { - break + return err + } + + if debug { + log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr()) } c2.SetWriteDeadline(time.Now().Add(networkTimeout)) _, err = c2.Write(buf[:n]) if err != nil { - break + return err } } - c1.Close() - c2.Close() - wg.Done() +} + +func (s *session) String() string { + return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5]) } diff --git a/cmd/relaysrv/session_listener.go b/cmd/relaysrv/session_listener.go index b78c4f4b6..6159ceef5 100644 --- a/cmd/relaysrv/session_listener.go +++ b/cmd/relaysrv/session_listener.go @@ -3,10 +3,11 @@ package main import ( - "io" "log" "net" "time" + + "github.com/syncthing/relaysrv/protocol" ) func sessionListener(addr string) { @@ -17,6 +18,7 @@ func sessionListener(addr string) { for { conn, err := listener.Accept() + setTCPOptions(conn) if err != nil { if debug { log.Println(err) @@ -33,27 +35,49 @@ func sessionListener(addr string) { } func sessionConnectionHandler(conn net.Conn) { - conn.SetReadDeadline(time.Now().Add(messageTimeout)) - key := make([]byte, 32) - - _, err := io.ReadFull(conn, key) + conn.SetDeadline(time.Now().Add(messageTimeout)) + message, err := protocol.ReadMessage(conn) if err != nil { + conn.Close() + return + } + + switch msg := message.(type) { + case protocol.JoinSessionRequest: + ses := findSession(string(msg.Key)) if debug { - log.Println("Failed to read key", err, conn.RemoteAddr()) + log.Println(conn.RemoteAddr(), "session lookup", ses) } - conn.Close() - return - } - ses := findSession(string(key)) - if debug { - log.Println("Key", key, "by", conn.RemoteAddr(), "session", ses) - } + if ses == nil { + protocol.WriteMessage(conn, protocol.ResponseNotFound) + conn.Close() + return + } - if ses != nil { - ses.AddConnection(conn) - } else { + if !ses.AddConnection(conn) { + if debug { + log.Println("Failed to add", conn.RemoteAddr(), "to session", ses) + } + protocol.WriteMessage(conn, protocol.ResponseAlreadyConnected) + conn.Close() + return + } + + err := protocol.WriteMessage(conn, protocol.ResponseSuccess) + if err != nil { + if debug { + log.Println("Failed to send session join response to ", conn.RemoteAddr(), "for", ses) + } + conn.Close() + return + } + conn.SetDeadline(time.Time{}) + default: + if debug { + log.Println("Unexpected message from", conn.RemoteAddr(), message) + } + protocol.WriteMessage(conn, protocol.ResponseUnexpectedMessage) conn.Close() - return } } diff --git a/cmd/relaysrv/testutil/main.go b/cmd/relaysrv/testutil/main.go new file mode 100644 index 000000000..10c222457 --- /dev/null +++ b/cmd/relaysrv/testutil/main.go @@ -0,0 +1,142 @@ +// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file). + +package main + +import ( + "bufio" + "crypto/tls" + "flag" + "log" + "net" + "net/url" + "os" + "path/filepath" + "time" + + syncthingprotocol "github.com/syncthing/protocol" + "github.com/syncthing/relaysrv/client" + "github.com/syncthing/relaysrv/protocol" +) + +func main() { + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + + var connect, relay, dir string + var join bool + + flag.StringVar(&connect, "connect", "", "Device ID to which to connect to") + flag.BoolVar(&join, "join", false, "Join relay") + flag.StringVar(&relay, "relay", "relay://127.0.0.1:22067", "Relay address") + flag.StringVar(&dir, "keys", ".", "Directory where cert.pem and key.pem is stored") + + flag.Parse() + + certFile, keyFile := filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem") + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalln("Failed to load X509 key pair:", err) + } + + id := syncthingprotocol.NewDeviceID(cert.Certificate[0]) + log.Println("ID:", id) + + uri, err := url.Parse(relay) + if err != nil { + log.Fatal(err) + } + + stdin := make(chan string) + + go stdinReader(stdin) + + if join { + log.Printf("Creating client") + relay := client.NewProtocolClient(uri, []tls.Certificate{cert}, nil) + log.Printf("Created client") + + go relay.Serve() + + recv := make(chan protocol.SessionInvitation) + + go func() { + log.Println("Starting invitation receiver") + for invite := range relay.Invitations { + select { + case recv <- invite: + log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port) + default: + log.Printf("Discarding invitation", invite) + } + } + }() + + for { + conn, err := client.JoinSession(<-recv) + if err != nil { + log.Fatalln("Failed to join", err) + } + log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr()) + connectToStdio(stdin, conn) + log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr()) + } + } else if connect != "" { + id, err := syncthingprotocol.DeviceIDFromString(connect) + if err != nil { + log.Fatal(err) + } + + invite, err := client.GetInvitationFromRelay(uri, id, []tls.Certificate{cert}) + if err != nil { + log.Fatal(err) + } + + log.Printf("Received invitation from %s on %s:%d", syncthingprotocol.DeviceIDFromBytes(invite.From), net.IP(invite.Address), invite.Port) + conn, err := client.JoinSession(invite) + if err != nil { + log.Fatalln("Failed to join", err) + } + log.Println("Joined", conn.RemoteAddr(), conn.LocalAddr()) + connectToStdio(stdin, conn) + log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr()) + } else { + log.Fatal("Requires either join or connect") + } +} + +func stdinReader(c chan<- string) { + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + c <- scanner.Text() + c <- "\n" + } +} + +func connectToStdio(stdin <-chan string, conn net.Conn) { + go func() { + + }() + + buf := make([]byte, 1024) + for { + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + n, err := conn.Read(buf[0:]) + if err != nil { + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + log.Println(err) + return + } + } + os.Stdout.Write(buf[:n]) + + select { + case msg := <-stdin: + _, err := conn.Write([]byte(msg)) + if err != nil { + return + } + default: + } + } +} diff --git a/cmd/relaysrv/utils.go b/cmd/relaysrv/utils.go index 5388ba32e..7d1f6bfa4 100644 --- a/cmd/relaysrv/utils.go +++ b/cmd/relaysrv/utils.go @@ -5,7 +5,6 @@ package main import ( "errors" "net" - "time" ) func setTCPOptions(conn net.Conn) error { @@ -19,7 +18,7 @@ func setTCPOptions(conn net.Conn) error { if err := tcpConn.SetNoDelay(true); err != nil { return err } - if err := tcpConn.SetKeepAlivePeriod(60 * time.Second); err != nil { + if err := tcpConn.SetKeepAlivePeriod(networkTimeout); err != nil { return err } if err := tcpConn.SetKeepAlive(true); err != nil { @@ -27,27 +26,3 @@ func setTCPOptions(conn net.Conn) error { } return nil } - -func sendMessage(msg message, conn net.Conn) error { - header, err := msg.header.MarshalXDR() - if err != nil { - return err - } - - err = conn.SetWriteDeadline(time.Now().Add(networkTimeout)) - if err != nil { - return err - } - - _, err = conn.Write(header) - if err != nil { - return err - } - - _, err = conn.Write(msg.payload) - if err != nil { - return err - } - - return nil -}