Merge pull request #12 from syncthing/ccreq

Enforce ClusterConfiguration at start, then no ordering
This commit is contained in:
Audrius Butkevicius 2015-06-26 14:44:56 +01:00
commit e19e2c123e
2 changed files with 46 additions and 29 deletions

View File

@ -31,8 +31,7 @@ const (
const ( const (
stateInitial = iota stateInitial = iota
stateCCRcvd stateReady
stateIdxRcvd
) )
// FileInfo flags // FileInfo flags
@ -103,7 +102,6 @@ type rawConnection struct {
id DeviceID id DeviceID
name string name string
receiver Model receiver Model
state int
cr *countingReader cr *countingReader
cw *countingWriter cw *countingWriter
@ -155,7 +153,6 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
id: deviceID, id: deviceID,
name: name, name: name,
receiver: nativeModel{receiver}, receiver: nativeModel{receiver},
state: stateInitial,
cr: cr, cr: cr,
cw: cw, cw: cw,
outbox: make(chan hdrMsg), outbox: make(chan hdrMsg),
@ -285,6 +282,7 @@ func (c *rawConnection) readerLoop() (err error) {
c.close(err) c.close(err)
}() }()
state := stateInitial
for { for {
select { select {
case <-c.closed: case <-c.closed:
@ -298,47 +296,54 @@ func (c *rawConnection) readerLoop() (err error) {
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case ClusterConfigMessage:
if state != stateInitial {
return fmt.Errorf("protocol error: cluster config message in state %d", state)
}
go c.receiver.ClusterConfig(c.id, msg)
state = stateReady
case IndexMessage: case IndexMessage:
switch hdr.msgType { switch hdr.msgType {
case messageTypeIndex: case messageTypeIndex:
if c.state < stateCCRcvd { if state != stateReady {
return fmt.Errorf("protocol error: index message in state %d", c.state) return fmt.Errorf("protocol error: index message in state %d", state)
} }
c.handleIndex(msg) c.handleIndex(msg)
c.state = stateIdxRcvd state = stateReady
case messageTypeIndexUpdate: case messageTypeIndexUpdate:
if c.state < stateIdxRcvd { if state != stateReady {
return fmt.Errorf("protocol error: index update message in state %d", c.state) return fmt.Errorf("protocol error: index update message in state %d", state)
} }
c.handleIndexUpdate(msg) c.handleIndexUpdate(msg)
state = stateReady
} }
case RequestMessage: case RequestMessage:
if c.state < stateIdxRcvd { if state != stateReady {
return fmt.Errorf("protocol error: request message in state %d", c.state) return fmt.Errorf("protocol error: request message in state %d", state)
} }
// Requests are handled asynchronously // Requests are handled asynchronously
go c.handleRequest(hdr.msgID, msg) go c.handleRequest(hdr.msgID, msg)
case ResponseMessage: case ResponseMessage:
if c.state < stateIdxRcvd { if state != stateReady {
return fmt.Errorf("protocol error: response message in state %d", c.state) return fmt.Errorf("protocol error: response message in state %d", state)
} }
c.handleResponse(hdr.msgID, msg) c.handleResponse(hdr.msgID, msg)
case pingMessage: case pingMessage:
if state != stateReady {
return fmt.Errorf("protocol error: ping message in state %d", state)
}
c.send(hdr.msgID, messageTypePong, pongMessage{}) c.send(hdr.msgID, messageTypePong, pongMessage{})
case pongMessage: case pongMessage:
c.handlePong(hdr.msgID) if state != stateReady {
return fmt.Errorf("protocol error: pong message in state %d", state)
case ClusterConfigMessage:
if c.state != stateInitial {
return fmt.Errorf("protocol error: cluster config message in state %d", c.state)
} }
go c.receiver.ClusterConfig(c.id, msg) c.handlePong(hdr.msgID)
c.state = stateCCRcvd
case CloseMessage: case CloseMessage:
return errors.New(msg.Reason) return errors.New(msg.Reason)

View File

@ -67,8 +67,10 @@ func TestPing(t *testing.T) {
ar, aw := io.Pipe() ar, aw := io.Pipe()
br, bw := io.Pipe() br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
c1 := NewConnection(c1ID, br, aw, nil, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
c0.ClusterConfig(ClusterConfigMessage{})
c1.ClusterConfig(ClusterConfigMessage{})
if ok := c0.ping(); !ok { if ok := c0.ping(); !ok {
t.Error("c0 ping failed") t.Error("c0 ping failed")
@ -81,8 +83,8 @@ func TestPing(t *testing.T) {
func TestPingErr(t *testing.T) { func TestPingErr(t *testing.T) {
e := errors.New("something broke") e := errors.New("something broke")
for i := 0; i < 16; i++ { for i := 0; i < 32; i++ {
for j := 0; j < 16; j++ { for j := 0; j < 32; j++ {
m0 := newTestModel() m0 := newTestModel()
m1 := newTestModel() m1 := newTestModel()
@ -92,12 +94,16 @@ func TestPingErr(t *testing.T) {
ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e} ebw := &ErrPipe{PipeWriter: *bw, max: j, err: e}
c0 := NewConnection(c0ID, ar, ebw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) c0 := NewConnection(c0ID, ar, ebw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, eaw, m1, "name", CompressAlways) c1 := NewConnection(c1ID, br, eaw, m1, "name", CompressAlways)
c0.ClusterConfig(ClusterConfigMessage{})
c1.ClusterConfig(ClusterConfigMessage{})
res := c0.ping() res := c0.ping()
if (i < 8 || j < 8) && res { if (i < 8 || j < 8) && res {
// This should have resulted in failure, as there is no way an empty ClusterConfig plus a Ping message fits in eight bytes.
t.Errorf("Unexpected ping success; i=%d, j=%d", i, j) t.Errorf("Unexpected ping success; i=%d, j=%d", i, j)
} else if (i >= 12 && j >= 12) && !res { } else if (i >= 28 && j >= 28) && !res {
// This should have worked though, as 28 bytes is plenty for both.
t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j) t.Errorf("Unexpected ping fail; i=%d, j=%d", i, j)
} }
} }
@ -168,7 +174,9 @@ func TestVersionErr(t *testing.T) {
br, bw := io.Pipe() br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", CompressAlways) c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
c0.ClusterConfig(ClusterConfigMessage{})
c1.ClusterConfig(ClusterConfigMessage{})
w := xdr.NewWriter(c0.cw) w := xdr.NewWriter(c0.cw)
w.WriteUint32(encodeHeader(header{ w.WriteUint32(encodeHeader(header{
@ -191,7 +199,9 @@ func TestTypeErr(t *testing.T) {
br, bw := io.Pipe() br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", CompressAlways) c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
c0.ClusterConfig(ClusterConfigMessage{})
c1.ClusterConfig(ClusterConfigMessage{})
w := xdr.NewWriter(c0.cw) w := xdr.NewWriter(c0.cw)
w.WriteUint32(encodeHeader(header{ w.WriteUint32(encodeHeader(header{
@ -214,7 +224,9 @@ func TestClose(t *testing.T) {
br, bw := io.Pipe() br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection) c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
NewConnection(c1ID, br, aw, m1, "name", CompressAlways) c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
c0.ClusterConfig(ClusterConfigMessage{})
c1.ClusterConfig(ClusterConfigMessage{})
c0.close(nil) c0.close(nil)