diff --git a/client_test.go b/client_test.go index 84b55d0f..9c373c6e 100644 --- a/client_test.go +++ b/client_test.go @@ -50,7 +50,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} - connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} + connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) originalClientConnConstructor = newClientConnection tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl) @@ -518,7 +518,7 @@ var _ = Describe("Client", func() { manager.EXPECT().Add(connID, gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber @@ -596,7 +596,7 @@ var _ = Describe("Client", func() { return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr("localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) @@ -605,14 +605,14 @@ var _ = Describe("Client", func() { }) }) -type mockedConnIDGenerator struct { +type mockConnIDGenerator struct { ConnID protocol.ConnectionID } -func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) { +func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) { return m.ConnID, nil } -func (m *mockedConnIDGenerator) ConnectionIDLen() int { +func (m *mockConnIDGenerator) ConnectionIDLen() int { return m.ConnID.Len() } diff --git a/conn_id_generator.go b/conn_id_generator.go index 0421d678..0a6aa855 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -14,7 +14,7 @@ type connIDGenerator struct { highestSeq uint64 activeSrcConnIDs map[uint64]protocol.ConnectionID - initialClientDestConnID protocol.ConnectionID + initialClientDestConnID *protocol.ConnectionID // nil for the client addConnectionID func(protocol.ConnectionID) getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken @@ -28,7 +28,7 @@ type connIDGenerator struct { func newConnIDGenerator( initialConnectionID protocol.ConnectionID, - initialClientDestConnID protocol.ConnectionID, // nil for the client + initialClientDestConnID *protocol.ConnectionID, // nil for the client addConnectionID func(protocol.ConnectionID), getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), @@ -117,14 +117,14 @@ func (m *connIDGenerator) issueNewConnID() error { func (m *connIDGenerator) SetHandshakeComplete() { if m.initialClientDestConnID != nil { - m.retireConnectionID(m.initialClientDestConnID) + m.retireConnectionID(*m.initialClientDestConnID) m.initialClientDestConnID = nil } } func (m *connIDGenerator) RemoveAll() { if m.initialClientDestConnID != nil { - m.removeConnectionID(m.initialClientDestConnID) + m.removeConnectionID(*m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { m.removeConnectionID(connID) @@ -134,7 +134,7 @@ func (m *connIDGenerator) RemoveAll() { func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { - connIDs = append(connIDs, m.initialClientDestConnID) + connIDs = append(connIDs, *m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 167a70d6..dc3a1223 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -20,11 +20,12 @@ var _ = Describe("Connection ID Generator", func() { queuedFrames []wire.Frame g *connIDGenerator ) - initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} - initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe} + initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) + initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe}) connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken { - return protocol.StatelessResetToken{c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0]} + b := c.Bytes()[0] + return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b} } BeforeEach(func() { @@ -35,7 +36,7 @@ var _ = Describe("Connection ID Generator", func() { replacedWithClosed = nil g = newConnIDGenerator( initialConnID, - initialClientDestConnID, + &initialClientDestConnID, func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) }, connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 6c849059..72d7d5af 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -16,7 +16,7 @@ var _ = Describe("Connection ID Manager", func() { tokenAdded *protocol.StatelessResetToken removedTokens []protocol.StatelessResetToken ) - initialConnID := protocol.ConnectionID{0, 0, 0, 0} + initialConnID := protocol.ParseConnectionID([]byte{0, 0, 0, 0}) BeforeEach(func() { frameQueue = nil @@ -34,7 +34,7 @@ var _ = Describe("Connection ID Manager", func() { get := func() (protocol.ConnectionID, protocol.StatelessResetToken) { if m.queue.Len() == 0 { - return nil, protocol.StatelessResetToken{} + return protocol.ConnectionID{}, protocol.StatelessResetToken{} } val := m.queue.Remove(m.queue.Front()) return val.ConnectionID, val.StatelessResetToken @@ -45,8 +45,8 @@ var _ = Describe("Connection ID Manager", func() { }) It("changes the initial connection ID", func() { - m.ChangeInitialConnID(protocol.ConnectionID{1, 2, 3, 4, 5}) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + m.ChangeInitialConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) }) It("sets the token for the first connection ID", func() { @@ -59,81 +59,81 @@ var _ = Describe("Connection ID Manager", func() { It("adds and gets connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 4, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, })).To(Succeed()) c1, rt1 := get() - Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(c1).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) c2, rt2 := get() - Expect(c2).To(Equal(protocol.ConnectionID{2, 3, 4, 5})) + Expect(c2).To(Equal(protocol.ParseConnectionID([]byte{2, 3, 4, 5}))) Expect(rt2).To(Equal(protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})) c3, _ := get() - Expect(c3).To(BeNil()) + Expect(c3).To(BeZero()) }) It("accepts duplicates", func() { f1 := &wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } f2 := &wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } Expect(m.Add(f1)).To(Succeed()) Expect(m.Add(f2)).To(Succeed()) c1, rt1 := get() - Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(c1).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) c2, _ := get() - Expect(c2).To(BeNil()) + Expect(c2).To(BeZero()) }) It("ignores duplicates for the currently used connection ID", func() { f := &wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } m.SetHandshakeComplete() Expect(m.Add(f)).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) c, _ := get() - Expect(c).To(BeNil()) + Expect(c).To(BeZero()) // Now send the same connection ID again. It should not be queued. Expect(m.Add(f)).To(Succeed()) c, _ = get() - Expect(c).To(BeNil()) + Expect(c).To(BeZero()) }) It("rejects duplicates with different connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), })).To(MatchError("received conflicting connection IDs for sequence number 42")) }) It("rejects duplicates with different connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, })).To(MatchError("received conflicting stateless reset tokens for sequence number 42")) }) @@ -141,29 +141,29 @@ var _ = Describe("Connection ID Manager", func() { It("retires connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 13, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), })).To(Succeed()) Expect(frameQueue).To(BeEmpty()) Expect(m.Add(&wire.NewConnectionIDFrame{ RetirePriorTo: 14, SequenceNumber: 17, - ConnectionID: protocol.ConnectionID{3, 4, 5, 6}, + ConnectionID: protocol.ParseConnectionID([]byte{3, 4, 5, 6}), })).To(Succeed()) Expect(frameQueue).To(HaveLen(3)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(10)) Expect(frameQueue[1].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(13)) Expect(frameQueue[2].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{3, 4, 5, 6})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{3, 4, 5, 6}))) }) It("ignores reordered connection IDs, if their sequence number was already retired", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), RetirePriorTo: 5, })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) @@ -173,7 +173,7 @@ var _ = Describe("Connection ID Manager", func() { // Make sure it gets retired immediately now. Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 4, - ConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(4)) @@ -182,17 +182,17 @@ var _ = Describe("Connection ID Manager", func() { It("ignores reordered connection IDs, if their sequence number was already retired or less than active", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), RetirePriorTo: 5, })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) frameQueue = nil - Expect(m.Get()).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 9, - ConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), RetirePriorTo: 5, })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) @@ -200,7 +200,7 @@ var _ = Describe("Connection ID Manager", func() { }) It("accepts retransmissions for the connection ID that is in use", func() { - connID := protocol.ConnectionID{1, 2, 3, 4} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, @@ -225,13 +225,13 @@ var _ = Describe("Connection ID Manager", func() { for i := uint8(1); i < protocol.MaxActiveConnectionIDs; i++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(i), - ConnectionID: protocol.ConnectionID{i, i, i, i}, + ConnectionID: protocol.ParseConnectionID([]byte{i, i, i, i}), StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, })).To(Succeed()) } Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(9999), - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, })).To(MatchError(&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})) }) @@ -241,22 +241,22 @@ var _ = Describe("Connection ID Manager", func() { m.SetHandshakeComplete() Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) }) It("waits until handshake completion before initiating a connection ID update", func() { Expect(m.Get()).To(Equal(initialConnID)) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) Expect(m.Get()).To(Equal(initialConnID)) m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) }) It("initiates subsequent updates when enough packets are sent", func() { @@ -264,14 +264,14 @@ var _ = Describe("Connection ID Manager", func() { for s = uint8(1); s < protocol.MaxActiveConnectionIDs; s++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, + ConnectionID: protocol.ParseConnectionID([]byte{s, s, s, s}), StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) } m.SetHandshakeComplete() lastConnID := m.Get() - Expect(lastConnID).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + Expect(lastConnID).To(Equal(protocol.ParseConnectionID([]byte{1, 1, 1, 1}))) var counter int for i := 0; i < 50*protocol.PacketsPerConnectionID; i++ { @@ -285,7 +285,7 @@ var _ = Describe("Connection ID Manager", func() { removedTokens = nil Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, + ConnectionID: protocol.ParseConnectionID([]byte{s, s, s, s}), StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) s++ @@ -298,28 +298,28 @@ var _ = Describe("Connection ID Manager", func() { for s := uint8(10); s <= 10+protocol.MaxActiveConnectionIDs/2; s++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, + ConnectionID: protocol.ParseConnectionID([]byte{s, s, s, s}), StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) } m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{10, 10, 10, 10})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{10, 10, 10, 10}))) for { m.SentPacket() - if m.Get().Equal(protocol.ConnectionID{11, 11, 11, 11}) { + if m.Get().Equal(protocol.ParseConnectionID([]byte{11, 11, 11, 11})) { break } } // The active conn ID is now {11, 11, 11, 11} - Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) + Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ParseConnectionID([]byte{12, 12, 12, 12}))) // Add a delayed connection ID. It should just be ignored now. frameQueue = nil Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(5), - ConnectionID: protocol.ConnectionID{5, 5, 5, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{5, 5, 5, 5}), StatelessResetToken: protocol.StatelessResetToken{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, })).To(Succeed()) - Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) + Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ParseConnectionID([]byte{12, 12, 12, 12}))) Expect(frameQueue).To(HaveLen(1)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(5)) }) @@ -328,21 +328,21 @@ var _ = Describe("Connection ID Manager", func() { for i := uint8(1); i <= protocol.MaxActiveConnectionIDs/2; i++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(i), - ConnectionID: protocol.ConnectionID{i, i, i, i}, + ConnectionID: protocol.ParseConnectionID([]byte{i, i, i, i}), StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, })).To(Succeed()) } m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 1, 1, 1}))) for i := 0; i < 2*protocol.PacketsPerConnectionID; i++ { m.SentPacket() } - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 1, 1, 1}))) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1337, - ConnectionID: protocol.ConnectionID{1, 3, 3, 7}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 3, 3, 7}), })).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{2, 2, 2, 2})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{2, 2, 2, 2}))) Expect(removedTokens).To(HaveLen(1)) Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})) }) @@ -352,11 +352,11 @@ var _ = Describe("Connection ID Manager", func() { Expect(removedTokens).To(BeEmpty()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) m.Close() Expect(removedTokens).To(HaveLen(1)) Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})) diff --git a/connection.go b/connection.go index 25563ffc..72c5e652 100644 --- a/connection.go +++ b/connection.go @@ -261,7 +261,7 @@ var newConnection = func( logger: logger, version: v, } - if origDestConnID != nil { + if origDestConnID.Len() > 0 { s.logID = origDestConnID.String() } else { s.logID = destConnID.String() @@ -274,7 +274,7 @@ var newConnection = func( ) s.connIDGenerator = newConnIDGenerator( srcConnID, - clientDestConnID, + &clientDestConnID, func(connID protocol.ConnectionID) { runner.Add(connID, s) }, runner.GetStatelessResetToken, runner.Remove, diff --git a/connection_test.go b/connection_test.go index 609ebff4..3956616d 100644 --- a/connection_test.go +++ b/connection_test.go @@ -49,9 +49,9 @@ var _ = Describe("Connection", func() { ) remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331} - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) + clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) getPacket := func(pn protocol.PacketNumber) *packedPacket { buffer := getPacketBuffer() @@ -91,7 +91,7 @@ var _ = Describe("Connection", func() { conn = newConnection( mconn, connRunner, - nil, + protocol.ConnectionID{}, nil, clientDestConnID, destConnID, @@ -270,11 +270,12 @@ var _ = Describe("Connection", func() { }) It("handles NEW_CONNECTION_ID frames", func() { + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) Expect(conn.handleFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: connID, }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Expect(conn.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(conn.connIDManager.queue.Back().Value.ConnectionID).To(Equal(connID)) }) It("handles PING frames", func() { @@ -664,8 +665,8 @@ var _ = Describe("Connection", func() { It("drops Version Negotiation packets", func() { b := wire.ComposeVersionNegotiation( - protocol.ArbitraryLenConnectionID(srcConnID), - protocol.ArbitraryLenConnectionID(destConnID), + protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), + protocol.ArbitraryLenConnectionID(destConnID.Bytes()), conn.config.Versions, ) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) @@ -1051,7 +1052,7 @@ var _ = Describe("Connection", func() { IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), Length: 1, Version: conn.version, }, @@ -1208,7 +1209,7 @@ var _ = Describe("Connection", func() { }) It("ignores coalesced packet parts if the destination connection IDs don't match", func() { - wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) Expect(srcConnID).ToNot(Equal(wrongConnID)) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { @@ -2412,8 +2413,8 @@ var _ = Describe("Client Connection", func() { tlsConf *tls.Config quicConf *Config ) - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} @@ -2452,7 +2453,7 @@ var _ = Describe("Client Connection", func() { mconn, connRunner, destConnID, - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), quicConf, tlsConf, 42, // initial packet number @@ -2485,7 +2486,7 @@ var _ = Describe("Client Connection", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) conn.run() }() - newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} + newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7}) p := getPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, @@ -2517,9 +2518,9 @@ var _ = Describe("Client Connection", func() { conn.connIDManager.SetHandshakeComplete() conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}), }) - Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) // now receive a packet with the original source connection ID unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { return &unpackedPacket{ @@ -2598,8 +2599,8 @@ var _ = Describe("Client Connection", func() { Context("handling Version Negotiation", func() { getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { b := wire.ComposeVersionNegotiation( - protocol.ArbitraryLenConnectionID(srcConnID), - protocol.ArbitraryLenConnectionID(destConnID), + protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), + protocol.ArbitraryLenConnectionID(destConnID.Bytes()), versions, ) return &receivedPacket{ @@ -2679,7 +2680,7 @@ var _ = Describe("Client Connection", func() { }) Context("handling Retry", func() { - origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + origDestConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) var retryHdr *wire.ExtendedHeader @@ -2688,8 +2689,8 @@ var _ = Describe("Client Connection", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Token: []byte("foobar"), Version: conn.version, }, @@ -2707,7 +2708,7 @@ var _ = Describe("Client Connection", func() { conn.sentPacketHandler = sph sph.EXPECT().ResetForRetry() sph.EXPECT().ReceivedBytes(gomock.Any()) - cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) + cryptoSetup.EXPECT().ChangeConnectionID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})) packer.EXPECT().SetToken([]byte("foobar")) tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { Expect(hdr.DestConnectionID).To(Equal(retryHdr.DestConnectionID)) @@ -2788,7 +2789,7 @@ var _ = Describe("Client Connection", func() { PreferredAddress: &wire.PreferredAddress{ IPv4: net.IPv4(127, 0, 0, 1), IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, }, } @@ -2801,7 +2802,7 @@ var _ = Describe("Client Connection", func() { cf, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(BeEmpty()) connRunner.EXPECT().AddResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, conn) - Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) // shut down connRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) expectClose(true) @@ -2823,10 +2824,10 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain a wrong initial_source_connection_id", func() { - conn.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.handshakeDestConnID = protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose(false) @@ -2839,7 +2840,8 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() { - conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + conn.retrySrcConnID = &rcid params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, @@ -2855,11 +2857,13 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() { - conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + rcid2 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) + conn.retrySrcConnID = &rcid params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + RetrySourceConnectionID: &rcid2, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose(false) @@ -2872,10 +2876,11 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain the retry_source_connection_id, if no Retry was performed", func() { + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + RetrySourceConnectionID: &rcid, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose(false) @@ -2888,9 +2893,9 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain a wrong original_destination_connection_id", func() { - conn.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.origDestConnID = protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) params := &wire.TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), InitialSourceConnectionID: conn.handshakeDestConnID, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } @@ -2948,7 +2953,7 @@ var _ = Describe("Client Connection", func() { IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), Length: 1, Version: conn.version, }, @@ -3014,7 +3019,7 @@ var _ = Describe("Client Connection", func() { conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) sph.EXPECT().ResetForRetry() - newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + newSrcConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) packer.EXPECT().SetToken([]byte("foobar")) diff --git a/framer_test.go b/framer_test.go index 71f7b2aa..13fe65fc 100644 --- a/framer_test.go +++ b/framer_test.go @@ -89,7 +89,10 @@ var _ = Describe("Framer", func() { It("drops *_BLOCKED frames when 0-RTT is rejected", func() { ping := &wire.PingFrame{} - ncid := &wire.NewConnectionIDFrame{SequenceNumber: 10, ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}} + ncid := &wire.NewConnectionIDFrame{ + SequenceNumber: 10, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + } frames := []wire.Frame{ &wire.DataBlockedFrame{MaximumData: 1337}, &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1337}, diff --git a/fuzzing/frames/cmd/corpus.go b/fuzzing/frames/cmd/corpus.go index eea06426..14a8360d 100644 --- a/fuzzing/frames/cmd/corpus.go +++ b/fuzzing/frames/cmd/corpus.go @@ -224,13 +224,13 @@ func getFrames() []wire.Frame { &wire.NewConnectionIDFrame{ SequenceNumber: seq1, RetirePriorTo: seq1 / 2, - ConnectionID: getRandomData(4), + ConnectionID: protocol.ParseConnectionID(getRandomData(4)), StatelessResetToken: token1, }, &wire.NewConnectionIDFrame{ SequenceNumber: seq2, RetirePriorTo: seq2, - ConnectionID: getRandomData(17), + ConnectionID: protocol.ParseConnectionID(getRandomData(17)), StatelessResetToken: token2, }, }...) diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index 1cccc4c6..0c02b699 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -31,23 +31,23 @@ func main() { headers := []wire.Header{ { // Initial without token IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(3)), - DestConnectionID: protocol.ConnectionID(getRandomData(8)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(3)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Initial without token, with zero-length src conn id IsLongHeader: true, - DestConnectionID: protocol.ConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Initial with Token IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(10)), - DestConnectionID: protocol.ConnectionID(getRandomData(19)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(10)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(19)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, @@ -55,37 +55,37 @@ func main() { }, { // Handshake packet IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(5)), - DestConnectionID: protocol.ConnectionID(getRandomData(10)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(5)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(10)), Type: protocol.PacketTypeHandshake, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Handshake packet, with zero-length src conn id IsLongHeader: true, - DestConnectionID: protocol.ConnectionID(getRandomData(12)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(12)), Type: protocol.PacketTypeHandshake, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // 0-RTT packet IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(8)), - DestConnectionID: protocol.ConnectionID(getRandomData(9)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(9)), Type: protocol.PacketType0RTT, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Retry Packet, with empty orig dest conn id IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(8)), - DestConnectionID: protocol.ConnectionID(getRandomData(9)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(9)), Type: protocol.PacketTypeRetry, Token: getRandomData(1000), Version: version, }, { // Short-Header - DestConnectionID: protocol.ConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), }, } diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index 1e1904ba..64d4576f 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -73,7 +73,7 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int { if token.SentTime.Before(start) || token.SentTime.After(time.Now()) { panic("incorrect send time") } - if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil { + if token.OriginalDestConnectionID.Len() > 0 || token.RetrySrcConnectionID.Len() > 0 { panic("didn't expect connection IDs") } return 1 @@ -89,12 +89,12 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int { if len(data) < origDestConnIDLen { return -1 } - origDestConnID := protocol.ConnectionID(data[:origDestConnIDLen]) + origDestConnID := protocol.ParseConnectionID(data[:origDestConnIDLen]) data = data[origDestConnIDLen:] if len(data) < retrySrcConnIDLen { return -1 } - retrySrcConnID := protocol.ConnectionID(data[:retrySrcConnIDLen]) + retrySrcConnID := protocol.ParseConnectionID(data[:retrySrcConnIDLen]) data = data[retrySrcConnIDLen:] if len(data) < 1 { diff --git a/fuzzing/transportparameters/cmd/corpus.go b/fuzzing/transportparameters/cmd/corpus.go index 8c88a6a8..0f472061 100644 --- a/fuzzing/transportparameters/cmd/corpus.go +++ b/fuzzing/transportparameters/cmd/corpus.go @@ -43,13 +43,13 @@ func main() { ActiveConnectionIDLimit: getRandomValue(), } if rand.Int()%2 == 0 { - tp.OriginalDestinationConnectionID = protocol.ConnectionID(getRandomData(rand.Intn(50))) + tp.OriginalDestinationConnectionID = protocol.ParseConnectionID(getRandomData(rand.Intn(21))) } if rand.Int()%2 == 0 { - tp.InitialSourceConnectionID = protocol.ConnectionID(getRandomData(rand.Intn(50))) + tp.InitialSourceConnectionID = protocol.ParseConnectionID(getRandomData(rand.Intn(21))) } if rand.Int()%2 == 0 { - connID := protocol.ConnectionID(getRandomData(rand.Intn(50))) + connID := protocol.ParseConnectionID(getRandomData(rand.Intn(21))) tp.RetrySourceConnectionID = &connID } if rand.Int()%2 == 0 { @@ -65,7 +65,7 @@ func main() { IPv4Port: uint16(rand.Int()), IPv6: net.IP(getRandomData(16)), IPv6Port: uint16(rand.Int()), - ConnectionID: protocol.ConnectionID(getRandomData(rand.Intn(25))), + ConnectionID: protocol.ParseConnectionID(getRandomData(rand.Intn(21))), StatelessResetToken: token, } } diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index dc47aa86..760d7131 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -19,13 +19,12 @@ type connIDGenerator struct { length int } -func (c *connIDGenerator) GenerateConnectionID() ([]byte, error) { +func (c *connIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) { b := make([]byte, c.length) - _, err := rand.Read(b) - if err != nil { + if _, err := rand.Read(b); err != nil { fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err) } - return b, nil + return protocol.ParseConnectionID(b), nil } func (c *connIDGenerator) ConnectionIDLen() int { diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index ff543fd4..0a08bb2e 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -367,7 +367,7 @@ var _ = Describe("MITM test", func() { } initialPacketIntercepted = true - fakeSrcConnID := protocol.ConnectionID{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12} + fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12}) retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) _, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr()) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index c39a02dd..e63c85d0 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -372,7 +372,7 @@ var _ = Describe("0-RTT", func() { It("retransmits all 0-RTT data when the server performs a Retry", func() { var mutex sync.Mutex - var firstConnID, secondConnID protocol.ConnectionID + var firstConnID, secondConnID *protocol.ConnectionID var firstCounter, secondCounter protocol.ByteCount tlsConf, clientConf := dialAndReceiveSessionTicket(nil) @@ -415,13 +415,13 @@ var _ = Describe("0-RTT", func() { if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 { if firstConnID == nil { - firstConnID = connID + firstConnID = &connID firstCounter += zeroRTTBytes } else if firstConnID != nil && firstConnID.Equal(connID) { Expect(secondConnID).To(BeNil()) firstCounter += zeroRTTBytes } else if secondConnID == nil { - secondConnID = connID + secondConnID = &connID secondCounter += zeroRTTBytes } else if secondConnID != nil && secondConnID.Equal(connID) { secondCounter += zeroRTTBytes diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index 35a56baa..3254d3df 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -35,8 +35,8 @@ var _ = Describe("QUIC Proxy", func() { Type: protocol.PacketTypeInitial, Version: protocol.VersionTLS, Length: 4 + protocol.ByteCount(len(payload)), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), }, PacketNumber: p, PacketNumberLen: protocol.PacketNumberLen4, diff --git a/interface.go b/interface.go index ea94aa45..8fbeab1f 100644 --- a/interface.go +++ b/interface.go @@ -201,6 +201,11 @@ type EarlyConnection interface { NextConnection() Connection } +// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. +// It is not able to handle QUIC Connection IDs longer than 20 bytes, +// as they are allowed by RFC 8999. +type ConnectionID = protocol.ConnectionID + // A ConnectionIDGenerator is an interface that allows clients to implement their own format // for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets. // @@ -208,7 +213,7 @@ type EarlyConnection interface { type ConnectionIDGenerator interface { // GenerateConnectionID generates a new ConnectionID. // Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs. - GenerateConnectionID() ([]byte, error) + GenerateConnectionID() (ConnectionID, error) // ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of // this interface. diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index 00ed243c..6128147c 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -62,7 +62,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p } func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { - initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) + initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v)) clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) return diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index bb8c4a15..a3f38ac6 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -18,7 +18,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { Expect(splitHexString("dead beef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) }) - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) DescribeTable("computes the client key and IV", func(v protocol.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { @@ -160,7 +160,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { Context(fmt.Sprintf("using version %s", v), func() { It("seals and opens", func() { - connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} + connectionID := protocol.ParseConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient, v) serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, v) @@ -175,8 +175,8 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { }) It("doesn't work if initialized with different connection IDs", func() { - c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} - c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} + c1 := protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1}) + c2 := protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2}) clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient, v) _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer, v) @@ -186,7 +186,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { }) It("encrypts und decrypts the header", func() { - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + connID := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient, v) serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer, v) diff --git a/internal/handshake/retry_test.go b/internal/handshake/retry_test.go index e1d3a215..017fa428 100644 --- a/internal/handshake/retry_test.go +++ b/internal/handshake/retry_test.go @@ -9,27 +9,30 @@ import ( var _ = Describe("Retry Integrity Check", func() { It("calculates retry integrity tags", func() { - fooTag := GetRetryIntegrityTag([]byte("foo"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) - barTag := GetRetryIntegrityTag([]byte("bar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + fooTag := GetRetryIntegrityTag([]byte("foo"), connID, protocol.VersionDraft29) + barTag := GetRetryIntegrityTag([]byte("bar"), connID, protocol.VersionDraft29) Expect(fooTag).ToNot(BeNil()) Expect(barTag).ToNot(BeNil()) Expect(*fooTag).ToNot(Equal(*barTag)) }) It("includes the original connection ID in the tag calculation", func() { - t1 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.Version1) - t2 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{4, 3, 2, 1}, protocol.Version1) + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + t1 := GetRetryIntegrityTag([]byte("foobar"), connID1, protocol.Version1) + t2 := GetRetryIntegrityTag([]byte("foobar"), connID2, protocol.Version1) Expect(*t1).ToNot(Equal(*t2)) }) It("uses the test vector from the draft, for old draft versions", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) data := splitHexString("ffff00001d0008f067a5502a4262b574 6f6b656ed16926d81f6f9ca2953a8aa4 575e1e49") Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.VersionDraft29)[:]).To(Equal(data[len(data)-16:])) }) It("uses the test vector from the draft, for version 1", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) data := splitHexString("ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f 0f2496ba") Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.Version1)[:]).To(Equal(data[len(data)-16:])) }) diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index a8dda91e..cda49466 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -65,8 +65,8 @@ func (g *TokenGenerator) NewRetryToken( data, err := asn1.Marshal(token{ IsRetryToken: true, RemoteAddr: encodeRemoteAddr(raddr), - OriginalDestConnectionID: origDestConnID, - RetrySrcConnectionID: retrySrcConnID, + OriginalDestConnectionID: origDestConnID.Bytes(), + RetrySrcConnectionID: retrySrcConnID.Bytes(), Timestamp: time.Now().UnixNano(), }) if err != nil { @@ -112,8 +112,8 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { encodedRemoteAddr: t.RemoteAddr, } if t.IsRetryToken { - token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) - token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) + token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID) + token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID) } return token, nil } diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 4d4be0f2..d674e72e 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -23,7 +23,7 @@ var _ = Describe("Token Generator", func() { It("generates a token", func() { ip := net.IPv4(127, 0, 0, 1) - token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil, nil) + token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) Expect(token).ToNot(BeEmpty()) }) @@ -36,7 +36,7 @@ var _ = Describe("Token Generator", func() { It("accepts a valid token", func() { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(addr, nil, nil) + tokenEnc, err := tokenGen.NewRetryToken(addr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) @@ -48,16 +48,14 @@ var _ = Describe("Token Generator", func() { }) It("saves the connection ID", func() { - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{}, - protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - ) + connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + connID2 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) + tokenEnc, err := tokenGen.NewRetryToken(&net.UDPAddr{}, connID1, connID2) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(token.RetrySrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(token.OriginalDestConnectionID).To(Equal(connID1)) + Expect(token.RetrySrcConnectionID).To(Equal(connID2)) }) It("rejects invalid tokens", func() { @@ -103,7 +101,7 @@ var _ = Describe("Token Generator", func() { ip := net.ParseIP(addr) Expect(ip).ToNot(BeNil()) raddr := &net.UDPAddr{IP: ip, Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + tokenEnc, err := tokenGen.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) @@ -114,7 +112,7 @@ var _ = Describe("Token Generator", func() { It("uses the string representation an address that is not a UDP address", func() { raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + tokenEnc, err := tokenGen.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index c6f8f35c..d3c607ca 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -1,12 +1,14 @@ package protocol import ( - "bytes" "crypto/rand" + "errors" "fmt" "io" ) +var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length") + // An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999. // Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1 // restricts the length to 20 bytes. @@ -27,18 +29,32 @@ func (c ArbitraryLenConnectionID) String() string { return fmt.Sprintf("%x", c.Bytes()) } -// A ConnectionID in QUIC -type ConnectionID []byte - const maxConnectionIDLen = 20 +// A ConnectionID in QUIC +type ConnectionID struct { + b [20]byte + l uint8 +} + // GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID(len int) (ConnectionID, error) { - b := make([]byte, len) - if _, err := rand.Read(b); err != nil { - return nil, err +func GenerateConnectionID(l int) (ConnectionID, error) { + var c ConnectionID + c.l = uint8(l) + _, err := rand.Read(c.b[:l]) + return c, err +} + +// ParseConnectionID interprets b as a Connection ID. +// It panics if b is longer than 20 bytes. +func ParseConnectionID(b []byte) ConnectionID { + if len(b) > maxConnectionIDLen { + panic("invalid conn id length") } - return ConnectionID(b), nil + var c ConnectionID + c.l = uint8(len(b)) + copy(c.b[:c.l], b) + return c } // GenerateConnectionIDForInitial generates a connection ID for the Initial packet. @@ -46,39 +62,43 @@ func GenerateConnectionID(len int) (ConnectionID, error) { func GenerateConnectionIDForInitial() (ConnectionID, error) { r := make([]byte, 1) if _, err := rand.Read(r); err != nil { - return nil, err + return ConnectionID{}, err } - len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) - return GenerateConnectionID(len) + l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(l) } // ReadConnectionID reads a connection ID of length len from the given io.Reader. // It returns io.EOF if there are not enough bytes to read. -func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { - if len == 0 { - return nil, nil +func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { + var c ConnectionID + if l == 0 { + return c, nil } - c := make(ConnectionID, len) - _, err := io.ReadFull(r, c) + if l > maxConnectionIDLen { + return c, ErrInvalidConnectionIDLen + } + c.l = uint8(l) + _, err := io.ReadFull(r, c.b[:l]) if err == io.ErrUnexpectedEOF { - return nil, io.EOF + return c, io.EOF } return c, err } // Equal says if two connection IDs are equal func (c ConnectionID) Equal(other ConnectionID) bool { - return bytes.Equal(c, other) + return c == other } // Len returns the length of the connection ID in bytes func (c ConnectionID) Len() int { - return len(c) + return int(c.l) } // Bytes returns the byte representation func (c ConnectionID) Bytes() []byte { - return []byte(c) + return c.b[:c.l] } func (c ConnectionID) String() string { @@ -92,7 +112,7 @@ type DefaultConnectionIDGenerator struct { ConnLen int } -func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) { +func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) { return GenerateConnectionID(d.ConnLen) } diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go index b9754f43..6b4cda61 100644 --- a/internal/protocol/connection_id_test.go +++ b/internal/protocol/connection_id_test.go @@ -44,10 +44,13 @@ var _ = Describe("Connection ID generation", func() { }) It("says if connection IDs are equal", func() { - c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + c1 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + c2 := ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) + c3 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) Expect(c1.Equal(c1)).To(BeTrue()) + Expect(c1.Equal(c3)).To(BeTrue()) Expect(c2.Equal(c2)).To(BeTrue()) + Expect(c2.Equal(c3)).To(BeFalse()) Expect(c1.Equal(c2)).To(BeFalse()) Expect(c2.Equal(c1)).To(BeFalse()) }) @@ -65,15 +68,21 @@ var _ = Describe("Connection ID generation", func() { Expect(err).To(MatchError(io.EOF)) }) - It("returns nil for a 0 length connection ID", func() { + It("returns a 0 length connection ID", func() { buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) c, err := ReadConnectionID(buf, 0) Expect(err).ToNot(HaveOccurred()) - Expect(c).To(BeNil()) + Expect(c.Len()).To(BeZero()) + }) + + It("errors when trying to read a too long connection ID", func() { + buf := bytes.NewBuffer(make([]byte, 21)) + _, err := ReadConnectionID(buf, 21) + Expect(err).To(MatchError(ErrInvalidConnectionIDLen)) }) It("returns the length", func() { - c := ConnectionID{1, 2, 3, 4, 5, 6, 7} + c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) Expect(c.Len()).To(Equal(7)) }) @@ -83,22 +92,22 @@ var _ = Describe("Connection ID generation", func() { }) It("returns the bytes", func() { - c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) + c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7})) }) It("returns a nil byte slice for the default value", func() { var c ConnectionID - Expect(c.Bytes()).To(BeNil()) + Expect(c.Bytes()).To(HaveLen(0)) }) It("has a string representation", func() { - c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) + c := ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) Expect(c.String()).To(Equal("deadbeef42")) }) It("has a long string representation", func() { - c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad} + c := ParseConnectionID([]byte{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}) Expect(c.String()).To(Equal("13370000decafbad")) }) diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index 51719e83..a3f25fb3 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -24,15 +24,15 @@ var _ = Describe("Header", func() { }) Context("Long Header", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) It("writes", func() { Expect((&ExtendedHeader{ Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}), Version: 0x1020304, Length: protocol.InitialPacketSizeIPv4, }, @@ -52,27 +52,12 @@ var _ = Describe("Header", func() { Expect(buf.Bytes()).To(Equal(expected)) }) - It("refuses to write a header with a too long connection ID", func() { - err := (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long - Version: 0x1020304, - Type: 0x5, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) - }) - It("writes a header with a 20 byte connection ID", func() { err := (&ExtendedHeader{ Header: Header{ IsLongHeader: true, SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, // connection IDs must be at most 20 bytes long + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long Version: 0x1020304, Type: 0x5, }, @@ -194,7 +179,7 @@ var _ = Describe("Header", func() { It("writes a header with connection ID", func() { Expect((&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 0x42, @@ -271,8 +256,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Length: 1, }, PacketNumberLen: protocol.PacketNumberLen1, @@ -288,8 +273,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Length: 1500, }, PacketNumberLen: protocol.PacketNumberLen2, @@ -305,8 +290,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 15, }, PacketNumberLen: protocol.PacketNumberLen2, @@ -322,8 +307,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 1500, }, PacketNumberLen: protocol.PacketNumberLen2, @@ -338,8 +323,8 @@ var _ = Describe("Header", func() { h := &ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Type: protocol.PacketTypeInitial, Length: 1500, Token: []byte("foo"), @@ -355,7 +340,7 @@ var _ = Describe("Header", func() { It("has the right length for a Short Header containing a connection ID", func() { h := &ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), }, PacketNumberLen: protocol.PacketNumberLen1, } @@ -407,8 +392,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}), Type: protocol.PacketTypeHandshake, Length: 54321, Version: 0xfeed, @@ -423,8 +408,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeInitial, Token: []byte{0xde, 0xad, 0xbe, 0xef}, Length: 100, @@ -440,8 +425,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeInitial, Length: 100, Version: 0xfeed, @@ -456,8 +441,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeRetry, Token: []byte{0x12, 0x34, 0x56}, Version: 0xfeed, @@ -469,7 +454,7 @@ var _ = Describe("Header", func() { It("logs Short Headers containing a connection ID", func() { (&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), }, KeyPhase: protocol.KeyPhaseOne, PacketNumber: 1337, diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index dcf570e8..8b88a378 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -210,7 +210,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks NEW_CONNECTION_ID frames", func() { f := &NewConnectionIDFrame{ SequenceNumber: 0x1337, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, } b, err := f.Append(nil, protocol.Version1) @@ -330,7 +330,7 @@ var _ = Describe("Frame parsing", func() { &DataBlockedFrame{}, &StreamDataBlockedFrame{}, &StreamsBlockedFrame{}, - &NewConnectionIDFrame{ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + &NewConnectionIDFrame{ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})}, &RetireConnectionIDFrame{}, &PathChallengeFrame{}, &PathResponseFrame{}, diff --git a/internal/wire/header.go b/internal/wire/header.go index 4c7eb926..e8a08242 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -17,22 +17,22 @@ import ( // That means that the connection ID must not be used after the packet buffer is released. func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { if len(data) == 0 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } if !IsLongHeaderPacket(data[0]) { if len(data) < shortHeaderConnIDLen+1 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } - return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil } if len(data) < 6 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } destConnIDLen := int(data[5]) if len(data) < 6+destConnIDLen { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } - return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil + return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil } // ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet, diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index b0149f28..88b045de 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -20,36 +20,36 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}), Version: protocol.Version1, }, PacketNumberLen: 2, }).Write(buf, protocol.Version1)).To(Succeed()) connID, err := ParseConnectionID(buf.Bytes(), 8) Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) }) It("parses the connection ID of a short header packet", func() { buf := &bytes.Buffer{} Expect((&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), }, PacketNumberLen: 2, }).Write(buf, protocol.Version1)).To(Succeed()) buf.Write([]byte("foobar")) connID, err := ParseConnectionID(buf.Bytes(), 4) Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) }) It("errors on EOF, for short header packets", func() { buf := &bytes.Buffer{} Expect((&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), }, PacketNumberLen: 2, }).Write(buf, protocol.Version1)).To(Succeed()) @@ -70,8 +70,8 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 8, 9}), Version: protocol.Version1, }, PacketNumberLen: 2, @@ -194,14 +194,14 @@ var _ = Describe("Header Parsing", func() { Context("Long Headers", func() { It("parses a Long Header", func() { - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + destConnID := protocol.ParseConnectionID([]byte{9, 8, 7, 6, 5, 4, 3, 2, 1}) + srcConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) data := []byte{0xc0 ^ 0x3} data = appendVersion(data, protocol.Version1) data = append(data, 0x9) // dest conn id length - data = append(data, destConnID...) + data = append(data, destConnID.Bytes()...) data = append(data, 0x4) // src conn id length - data = append(data, srcConnID...) + data = append(data, srcConnID.Bytes()...) data = append(data, encodeVarInt(6)...) // token length data = append(data, []byte("foobar")...) // token data = append(data, encodeVarInt(10)...) // length @@ -256,38 +256,50 @@ var _ = Describe("Header Parsing", func() { Expect(err).To(MatchError(ErrUnsupportedVersion)) Expect(hdr.IsLongHeader).To(BeTrue()) Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef))) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}))) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1}))) Expect(rest).To(BeEmpty()) }) It("parses a Long Header without a destination connection ID", func() { data := []byte{0xc0 ^ 0x1<<4} data = appendVersion(data, protocol.Version1) - data = append(data, 0x0) // dest conn ID len - data = append(data, 0x4) // src conn ID len + data = append(data, 0) // dest conn ID len + data = append(data, 4) // src conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(hdr.DestConnectionID).To(BeEmpty()) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) + Expect(hdr.DestConnectionID).To(BeZero()) }) It("parses a Long Header without a source connection ID", func() { data := []byte{0xc0 ^ 0x2<<4} data = appendVersion(data, protocol.Version1) - data = append(data, 0xa) // dest conn ID len + data = append(data, 10) // dest conn ID len data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID - data = append(data, 0x0) // src conn ID len + data = append(data, 0) // src conn ID len data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.SrcConnectionID).To(BeEmpty()) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.SrcConnectionID).To(BeZero()) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) + }) + + It("parses a Long Header without a too long destination connection ID", func() { + data := []byte{0xc0 ^ 0x2<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 21) // dest conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // dest connection ID + data = append(data, 0x0) // src conn ID len + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError(protocol.ErrInvalidConnectionIDLen)) }) It("parses a Long Header with a 2 byte packet number", func() { @@ -321,8 +333,8 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(hdr.Version).To(Equal(protocol.Version1)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{6, 5, 4, 3, 2, 1}))) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) Expect(hdr.Token).To(Equal([]byte("foobar"))) Expect(pdata).To(Equal(data)) Expect(rest).To(BeEmpty()) @@ -341,8 +353,8 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(hdr.Version).To(Equal(protocol.Version2)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{6, 5, 4, 3, 2, 1}))) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) Expect(hdr.Token).To(Equal([]byte("foobar"))) Expect(pdata).To(Equal(data)) Expect(rest).To(BeEmpty()) @@ -439,7 +451,7 @@ var _ = Describe("Header Parsing", func() { hdr := Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 2 + 6, Version: protocol.Version1, } @@ -465,7 +477,7 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 3, Version: protocol.Version1, }, @@ -483,7 +495,7 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 1000, Version: protocol.Version1, }, @@ -499,8 +511,8 @@ var _ = Describe("Header Parsing", func() { Context("Short Headers", func() { It("reads a Short Header with a 8 byte connection ID", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x40}, connID...) + connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) + data := append([]byte{0x40}, connID.Bytes()...) data = append(data, 0x42) // packet number Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) @@ -513,7 +525,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(extHdr.SrcConnectionID).To(BeZero()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) @@ -522,15 +534,15 @@ var _ = Describe("Header Parsing", func() { }) It("errors if 0x40 is not set", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x0}, connID...) + connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) + data := append([]byte{0x0}, connID.Bytes()...) _, _, _, err := ParsePacket(data, 8) Expect(err).To(MatchError("not a QUIC packet")) }) It("errors if the 4th or 5th bit are set", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID...) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID.Bytes()...) data = append(data, 0x42) // packet number hdr, _, _, err := ParsePacket(data, 5) Expect(err).ToNot(HaveOccurred()) @@ -542,8 +554,8 @@ var _ = Describe("Header Parsing", func() { }) It("reads a Short Header with a 5 byte connection ID", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40}, connID...) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + data := append([]byte{0x40}, connID.Bytes()...) data = append(data, 0x42) // packet number hdr, pdata, rest, err := ParsePacket(data, 5) Expect(err).ToNot(HaveOccurred()) @@ -555,7 +567,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(extHdr.SrcConnectionID).To(BeZero()) Expect(rest).To(BeEmpty()) }) diff --git a/internal/wire/log_test.go b/internal/wire/log_test.go index 6e970c1c..c7913804 100644 --- a/internal/wire/log_test.go +++ b/internal/wire/log_test.go @@ -153,7 +153,7 @@ var _ = Describe("Frame logging", func() { It("logs NEW_CONNECTION_ID frames", func() { LogFrame(logger, &NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, }, false) Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go index befc4037..828cda3b 100644 --- a/internal/wire/new_connection_id_frame.go +++ b/internal/wire/new_connection_id_frame.go @@ -38,9 +38,6 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC if err != nil { return nil, err } - if connIDLen > protocol.MaxConnIDLen { - return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) - } connID, err := protocol.ReadConnectionID(r, int(connIDLen)) if err != nil { return nil, err diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go index fa9f53aa..f289cb65 100644 --- a/internal/wire/new_connection_id_frame_test.go +++ b/internal/wire/new_connection_id_frame_test.go @@ -24,7 +24,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { Expect(err).ToNot(HaveOccurred()) Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) - Expect(frame.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(frame.ConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) Expect(string(frame.StatelessResetToken[:])).To(Equal("deadbeefdecafbad")) }) @@ -49,7 +49,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token b := bytes.NewReader(data) _, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).To(MatchError("invalid connection ID length: 21")) + Expect(err).To(MatchError(protocol.ErrInvalidConnectionIDLen)) }) It("errors on EOFs", func() { @@ -74,7 +74,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { frame := &NewConnectionIDFrame{ SequenceNumber: 0x1337, RetirePriorTo: 0x42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}), StatelessResetToken: token, } b, err := frame.Append(nil, protocol.Version1) @@ -93,7 +93,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { frame := &NewConnectionIDFrame{ SequenceNumber: 0xdecafbad, RetirePriorTo: 0xdeadbeefcafe, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), StatelessResetToken: token, } b, err := frame.Append(nil, protocol.Version1) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 81d33c60..b5f478fb 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -41,6 +41,7 @@ var _ = Describe("Transport Parameters", func() { } It("has a string representation", func() { + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) p := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, @@ -49,9 +50,9 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + RetrySourceConnectionID: &rcid, AckDelayExponent: 14, MaxAckDelay: 37 * time.Millisecond, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, @@ -70,8 +71,8 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{}), AckDelayExponent: 14, MaxAckDelay: 37 * time.Second, ActiveConnectionIDLimit: 89, @@ -83,6 +84,7 @@ var _ = Describe("Transport Parameters", func() { It("marshals and unmarshals", func() { var token protocol.StatelessResetToken rand.Read(token[:]) + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) params := &TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), @@ -93,9 +95,9 @@ var _ = Describe("Transport Parameters", func() { MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), DisableActiveMigration: true, StatelessResetToken: &token, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + RetrySourceConnectionID: &rcid, AckDelayExponent: 13, MaxAckDelay: 42 * time.Millisecond, ActiveConnectionIDLimit: getRandomValue(), @@ -114,9 +116,9 @@ var _ = Describe("Transport Parameters", func() { Expect(p.MaxIdleTimeout).To(Equal(params.MaxIdleTimeout)) Expect(p.DisableActiveMigration).To(Equal(params.DisableActiveMigration)) Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) - Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(p.InitialSourceConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - Expect(p.RetrySourceConnectionID).To(Equal(&protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) + Expect(p.InitialSourceConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) + Expect(p.RetrySourceConnectionID).To(Equal(&rcid)) Expect(p.AckDelayExponent).To(Equal(uint8(13))) Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond)) Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) @@ -133,8 +135,9 @@ var _ = Describe("Transport Parameters", func() { }) It("marshals a zero-length retry_source_connection_id", func() { + rcid := protocol.ParseConnectionID([]byte{}) data := (&TransportParameters{ - RetrySourceConnectionID: &protocol.ConnectionID{}, + RetrySourceConnectionID: &rcid, StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} @@ -406,7 +409,7 @@ var _ = Describe("Transport Parameters", func() { IPv4Port: 42, IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, IPv6Port: 13, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, } }) @@ -439,7 +442,7 @@ var _ = Describe("Transport Parameters", func() { }) It("errors on zero-length connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{} + pa.ConnectionID = protocol.ParseConnectionID([]byte{}) data := (&TransportParameters{ PreferredAddress: pa, StatelessResetToken: &protocol.StatelessResetToken{}, @@ -451,20 +454,6 @@ var _ = Describe("Transport Parameters", func() { })) }) - It("errors on too long connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21} - Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen)) - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid connection ID length: 21", - })) - }) - It("errors on EOF", func() { raw := []byte{ 127, 0, 0, 1, // IPv4 diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 55c45564..f5a888de 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -6,6 +6,8 @@ import ( "net" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -38,18 +40,20 @@ var _ = Describe("Tracing", func() { It("multiplexes the TracerForConnection call", func() { ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + connID := protocol.ParseConnectionID([]byte{1, 2, 3}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + tracer.TracerForConnection(ctx, PerspectiveClient, connID) }) It("uses multiple connection tracers", func() { ctx := context.Background() ctr1 := NewMockConnectionTracer(mockCtrl) ctr2 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + connID := protocol.ParseConnectionID([]byte{1, 2, 3}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) ctr1.EXPECT().LossTimerCanceled() ctr2.EXPECT().LossTimerCanceled() tr.LossTimerCanceled() @@ -58,23 +62,25 @@ var _ = Describe("Tracing", func() { It("handles tracers that return a nil ConnectionTracer", func() { ctx := context.Background() ctr1 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) ctr1.EXPECT().LossTimerCanceled() tr.LossTimerCanceled() }) It("returns nil when all tracers return a nil ConnectionTracer", func() { ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil()) }) It("traces the PacketSent event", func() { remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} f := &MaxDataFrame{MaximumData: 1337} tr1.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) tr2.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) @@ -116,9 +122,11 @@ var _ = Describe("Tracing", func() { It("trace the ConnectionStarted event", func() { local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - tr1.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tr2.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tracer.StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + dest := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + src := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + tr1.EXPECT().StartedConnection(local, remote, src, dest) + tr2.EXPECT().StartedConnection(local, remote, src, dest) + tracer.StartedConnection(local, remote, src, dest) }) It("traces the ClosedConnection event", func() { @@ -150,7 +158,7 @@ var _ = Describe("Tracing", func() { }) It("traces the SentPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} ping := &PingFrame{} tr1.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) @@ -167,14 +175,14 @@ var _ = Describe("Tracing", func() { }) It("traces the ReceivedRetry event", func() { - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} tr1.EXPECT().ReceivedRetry(hdr) tr2.EXPECT().ReceivedRetry(hdr) tracer.ReceivedRetry(hdr) }) It("traces the ReceivedPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ping := &PingFrame{} tr1.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) tr2.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) diff --git a/packet_handler_map.go b/packet_handler_map.go index 0caa4907..6018765a 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -48,7 +48,7 @@ type packetHandlerMap struct { closeQueue chan closePacket - handlers map[string] /* string(ConnectionID)*/ packetHandler + handlers map[protocol.ConnectionID]packetHandler resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler server unknownPacketHandler numZeroRTTEntries int @@ -127,7 +127,7 @@ func newPacketHandlerMap( conn: conn, connIDLen: connIDLen, listening: make(chan struct{}), - handlers: make(map[string]packetHandler), + handlers: make(map[protocol.ConnectionID]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, @@ -176,11 +176,11 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) h.mutex.Lock() defer h.mutex.Unlock() - if _, ok := h.handlers[string(id)]; ok { + if _, ok := h.handlers[id]; ok { h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) return false } - h.handlers[string(id)] = handler + h.handlers[id] = handler h.logger.Debugf("Adding connection ID %s.", id) return true } @@ -190,7 +190,7 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co defer h.mutex.Unlock() var q *zeroRTTQueue - if handler, ok := h.handlers[string(clientDestConnID)]; ok { + if handler, ok := h.handlers[clientDestConnID]; ok { q, ok = handler.(*zeroRTTQueue) if !ok { h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) @@ -206,15 +206,15 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co if q != nil { q.EnqueueAll(conn) } - h.handlers[string(clientDestConnID)] = conn - h.handlers[string(newConnID)] = conn + h.handlers[clientDestConnID] = conn + h.handlers[newConnID] = conn h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) return true } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { h.mutex.Lock() - delete(h.handlers, string(id)) + delete(h.handlers, id) h.mutex.Unlock() h.logger.Debugf("Removing connection ID %s.", id) } @@ -223,7 +223,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) time.AfterFunc(h.deleteRetiredConnsAfter, func() { h.mutex.Lock() - delete(h.handlers, string(id)) + delete(h.handlers, id) h.mutex.Unlock() h.logger.Debugf("Removing connection ID %s after it has been retired.", id) }) @@ -254,7 +254,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p h.mutex.Lock() for _, id := range ids { - h.handlers[string(id)] = handler + h.handlers[id] = handler } h.mutex.Unlock() h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids) @@ -263,7 +263,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p h.mutex.Lock() handler.shutdown() for _, id := range ids { - delete(h.handlers, string(id)) + delete(h.handlers, id) } h.mutex.Unlock() h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids) @@ -394,7 +394,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { return } - if handler, ok := h.handlers[string(connID)]; ok { + if handler, ok := h.handlers[connID]; ok { if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue if wire.Is0RTTPacket(p.data) { ha.handlePacket(p) @@ -419,15 +419,15 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { } h.numZeroRTTEntries++ queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} - h.handlers[string(connID)] = queue + h.handlers[connID] = queue queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { h.mutex.Lock() defer h.mutex.Unlock() // The entry might have been replaced by an actual connection. // Only delete it if it's still a 0-RTT queue. - if handler, ok := h.handlers[string(connID)]; ok { + if handler, ok := h.handlers[connID]; ok { if q, ok := handler.(*zeroRTTQueue); ok { - delete(h.handlers, string(connID)) + delete(h.handlers, connID) h.numZeroRTTEntries-- if h.numZeroRTTEntries < 0 { panic("number of 0-RTT queues < 0") diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 63f8e853..cc4b9421 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -93,8 +93,8 @@ var _ = Describe("Packet Handler Map", func() { conn1.EXPECT().destroy(testErr) conn2 := NewMockPacketHandler(mockCtrl) conn2.EXPECT().destroy(testErr) - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, conn1) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, conn2) + handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1) + handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2) mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) handler.close(testErr) close(packetChan) @@ -123,8 +123,8 @@ var _ = Describe("Packet Handler Map", func() { }) It("handles packets for different packet handlers on the same packet conn", func() { - connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) packetHandler1 := NewMockPacketHandler(mockCtrl) packetHandler2 := NewMockPacketHandler(mockCtrl) handledPacket1 := make(chan struct{}) @@ -162,7 +162,7 @@ var _ = Describe("Packet Handler Map", func() { It("deletes removed connections immediately", func() { handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) handler.handlePacket(&receivedPacket{data: getPacket(connID)}) @@ -171,7 +171,7 @@ var _ = Describe("Packet Handler Map", func() { It("deletes retired connection entries after a wait time", func() { handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) conn := NewMockPacketHandler(mockCtrl) handler.Add(connID, conn) handler.Retire(connID) @@ -182,7 +182,7 @@ var _ = Describe("Packet Handler Map", func() { It("passes packets arriving late for closed connections to that connection", func() { handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) packetHandler := NewMockPacketHandler(mockCtrl) handled := make(chan struct{}) packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { @@ -195,7 +195,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops packets for unknown receivers", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) handler.handlePacket(&receivedPacket{data: getPacket(connID)}) }) @@ -206,14 +206,14 @@ var _ = Describe("Packet Handler Map", func() { Expect(e).To(HaveOccurred()) close(done) }) - handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) packetChan <- packetToRead{err: errors.New("read failed")} Eventually(done).Should(BeClosed()) }) It("continues listening for temporary errors", func() { packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) err := deadlineError{} Expect(err.Temporary()).To(BeTrue()) packetChan <- packetToRead{err: err} @@ -222,15 +222,15 @@ var _ = Describe("Packet Handler Map", func() { }) It("says if a connection ID is already taken", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) }) It("says if a connection ID is already taken, for AddWithConnID", func() { - clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - newConnID1 := protocol.ConnectionID{1, 2, 3, 4} - newConnID2 := protocol.ConnectionID{4, 3, 2, 1} + clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) }) @@ -238,7 +238,7 @@ var _ = Describe("Packet Handler Map", func() { Context("running a server", func() { It("adds a server", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { @@ -258,13 +258,13 @@ var _ = Describe("Packet Handler Map", func() { serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) serverConn.EXPECT().shutdown() - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientConn) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverConn) + handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn) + handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn) handler.CloseServer() }) It("stops handling packets with unknown connection IDs after the server is closed", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) // don't EXPECT any calls to server.handlePacket @@ -286,7 +286,7 @@ var _ = Describe("Packet Handler Map", func() { server := NewMockUnknownPacketHandler(mockCtrl) // don't EXPECT any calls to server.handlePacket handler.SetServer(server) - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)} p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)} @@ -300,14 +300,14 @@ var _ = Describe("Packet Handler Map", func() { conn.EXPECT().handlePacket(p2), conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), ) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) Eventually(done).Should(BeClosed()) }) It("directs 0-RTT packets to existing connections", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} conn.EXPECT().handlePacket(p1) handler.handlePacket(p1) @@ -315,17 +315,21 @@ var _ = Describe("Packet Handler Map", func() { It("limits the number of 0-RTT queues", func() { for i := 0; i < protocol.Max0RTTQueues; i++ { - connID := make(protocol.ConnectionID, 8) - rand.Read(connID) - p := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + b := make([]byte, 8) + rand.Read(b) + p := &receivedPacket{data: getPacketWithPacketType( + protocol.ParseConnectionID(b), + protocol.PacketType0RTT, + 1, + )} handler.handlePacket(p) } // We're already storing the maximum number of queues. This packet will be dropped. - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) // Don't EXPECT any handlePacket() calls. conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) time.Sleep(20 * time.Millisecond) }) @@ -336,7 +340,7 @@ var _ = Describe("Packet Handler Map", func() { server := NewMockUnknownPacketHandler(mockCtrl) // don't EXPECT any calls to server.handlePacket handler.SetServer(server) - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p1 := &receivedPacket{ data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1), buffer: getPacketBuffer(), @@ -351,7 +355,7 @@ var _ = Describe("Packet Handler Map", func() { time.Sleep(queueDuration * 3) // Don't EXPECT any handlePacket() calls. conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) time.Sleep(20 * time.Millisecond) }) }) @@ -404,7 +408,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("removes reset tokens", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) packetHandler := NewMockPacketHandler(mockCtrl) handler.Add(connID, packetHandler) token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} @@ -442,8 +446,8 @@ var _ = Describe("Packet Handler Map", func() { }) It("generates stateless reset tokens", func() { - connID1 := []byte{0xde, 0xad, 0xbe, 0xef} - connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) }) diff --git a/packet_packer_test.go b/packet_packer_test.go index ac095939..6b50a82c 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -38,7 +38,7 @@ var _ = Describe("Packet packer", func() { sealingManager *MockSealingManager pnManager *mockackhandler.MockSentPacketHandler ) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) parsePacket := func(data []byte) []*wire.ExtendedHeader { var hdrs []*wire.ExtendedHeader @@ -94,7 +94,7 @@ var _ = Describe("Packet packer", func() { datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) packer = newPacketPacker( - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), func() protocol.ConnectionID { return connID }, initialStream, handshakeStream, @@ -141,8 +141,8 @@ var _ = Describe("Packet packer", func() { It("sets source and destination connection ID", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) packer.srcConnID = srcConnID packer.getDestConnID = func() protocol.ConnectionID { return destConnID } h := packer.getLongHeader(protocol.EncryptionHandshake) @@ -616,7 +616,7 @@ var _ = Describe("Packet packer", func() { Expect(packet.packets).To(HaveLen(1)) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.buffer.Data) extHdr, err := hdr.ParseExtended(r, packer.version) @@ -656,7 +656,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.buffer.Data) extHdr, err := hdr.ParseExtended(r, packer.version) @@ -1206,7 +1206,7 @@ var _ = Describe("Packet packer", func() { Expect(packet.packets).To(HaveLen(1)) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.buffer.Data) extHdr, err := hdr.ParseExtended(r, packer.version) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 16c708e6..813cf82a 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -23,7 +23,7 @@ var _ = Describe("Packet Unpacker", func() { var ( unpacker *packetUnpacker cs *mocks.MockCryptoSetup - connID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + connID = protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) payload = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ) diff --git a/qlog/event.go b/qlog/event.go index 90c38995..83423aa5 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -80,8 +80,8 @@ func (e eventConnectionStarted) MarshalJSONObject(enc *gojay.Encoder) { enc.IntKey("src_port", e.SrcAddr.Port) enc.StringKey("dst_ip", e.DestAddr.IP.String()) enc.IntKey("dst_port", e.DestAddr.Port) - enc.StringKey("src_cid", connectionID(e.SrcConnectionID).String()) - enc.StringKey("dst_cid", connectionID(e.DestConnectionID).String()) + enc.StringKey("src_cid", e.SrcConnectionID.String()) + enc.StringKey("dst_cid", e.DestConnectionID.String()) } type eventVersionNegotiated struct { @@ -410,15 +410,15 @@ func (e eventTransportParameters) MarshalJSONObject(enc *gojay.Encoder) { if !e.Restore { enc.StringKey("owner", e.Owner.String()) if e.SentBy == protocol.PerspectiveServer { - enc.StringKey("original_destination_connection_id", connectionID(e.OriginalDestinationConnectionID).String()) + enc.StringKey("original_destination_connection_id", e.OriginalDestinationConnectionID.String()) if e.StatelessResetToken != nil { enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", e.StatelessResetToken[:])) } if e.RetrySourceConnectionID != nil { - enc.StringKey("retry_source_connection_id", connectionID(*e.RetrySourceConnectionID).String()) + enc.StringKey("retry_source_connection_id", (*e.RetrySourceConnectionID).String()) } } - enc.StringKey("initial_source_connection_id", connectionID(e.InitialSourceConnectionID).String()) + enc.StringKey("initial_source_connection_id", e.InitialSourceConnectionID.String()) } enc.BoolKey("disable_active_migration", e.DisableActiveMigration) enc.FloatKeyOmitEmpty("max_idle_timeout", milliseconds(e.MaxIdleTimeout)) @@ -457,7 +457,7 @@ func (a preferredAddress) MarshalJSONObject(enc *gojay.Encoder) { enc.Uint16Key("port_v4", a.PortV4) enc.StringKey("ip_v6", a.IPv6.String()) enc.Uint16Key("port_v6", a.PortV6) - enc.StringKey("connection_id", connectionID(a.ConnectionID).String()) + enc.StringKey("connection_id", a.ConnectionID.String()) enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", a.StatelessResetToken)) } diff --git a/qlog/frame.go b/qlog/frame.go index 4530f0fb..35761dae 100644 --- a/qlog/frame.go +++ b/qlog/frame.go @@ -182,7 +182,7 @@ func marshalNewConnectionIDFrame(enc *gojay.Encoder, f *logging.NewConnectionIDF enc.Int64Key("sequence_number", int64(f.SequenceNumber)) enc.Int64Key("retire_prior_to", int64(f.RetirePriorTo)) enc.IntKey("length", f.ConnectionID.Len()) - enc.StringKey("connection_id", connectionID(f.ConnectionID).String()) + enc.StringKey("connection_id", f.ConnectionID.String()) enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", f.StatelessResetToken)) } diff --git a/qlog/frame_test.go b/qlog/frame_test.go index b5e553e8..bb98f9f8 100644 --- a/qlog/frame_test.go +++ b/qlog/frame_test.go @@ -273,7 +273,7 @@ var _ = Describe("Frames", func() { &logging.NewConnectionIDFrame{ SequenceNumber: 42, RetirePriorTo: 24, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, }, map[string]interface{}{ diff --git a/qlog/packet_header.go b/qlog/packet_header.go index 2fef2033..0b77936d 100644 --- a/qlog/packet_header.go +++ b/qlog/packet_header.go @@ -81,12 +81,12 @@ func (h packetHeader) MarshalJSONObject(enc *gojay.Encoder) { if h.PacketType != logging.PacketType1RTT { enc.IntKey("scil", h.SrcConnectionID.Len()) if h.SrcConnectionID.Len() > 0 { - enc.StringKey("scid", connectionID(h.SrcConnectionID).String()) + enc.StringKey("scid", h.SrcConnectionID.String()) } } enc.IntKey("dcil", h.DestConnectionID.Len()) if h.DestConnectionID.Len() > 0 { - enc.StringKey("dcid", connectionID(h.DestConnectionID).String()) + enc.StringKey("dcid", h.DestConnectionID.String()) } if h.KeyPhaseBit == logging.KeyPhaseZero || h.KeyPhaseBit == logging.KeyPhaseOne { enc.StringKey("key_phase_bit", h.KeyPhaseBit.String()) diff --git a/qlog/packet_header_test.go b/qlog/packet_header_test.go index 54fe782a..f5b2c033 100644 --- a/qlog/packet_header_test.go +++ b/qlog/packet_header_test.go @@ -97,7 +97,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0x11, 0x22, 0x33, 0x44}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44}), Version: protocol.VersionNumber(0xdecafbad), Token: []byte{0xde, 0xad, 0xbe, 0xef}, }, @@ -140,7 +140,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - SrcConnectionID: protocol.ConnectionID{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}), Version: protocol.VersionNumber(0xdecafbad), }, }, @@ -159,7 +159,7 @@ var _ = Describe("Packet Header", func() { check( &wire.ExtendedHeader{ PacketNumber: 42, - Header: wire.Header{DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + Header: wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})}, KeyPhase: protocol.KeyPhaseOne, }, map[string]interface{}{ diff --git a/qlog/qlog.go b/qlog/qlog.go index 3c921e03..7adfc790 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -108,8 +108,8 @@ func (t *connectionTracer) run() { trace: trace{ VantagePoint: vantagePoint{Type: t.perspective}, CommonFields: commonFields{ - ODCID: connectionID(t.odcid), - GroupID: connectionID(t.odcid), + ODCID: t.odcid, + GroupID: t.odcid, ReferenceTime: t.referenceTime, }, }, diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index 350df017..c37b6da3 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -54,7 +54,11 @@ var _ = Describe("Tracing", func() { Context("tracer", func() { It("returns nil when there's no io.WriteCloser", func() { t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) - Expect(t.TracerForConnection(context.Background(), logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) + Expect(t.TracerForConnection( + context.Background(), + logging.PerspectiveClient, + protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + )).To(BeNil()) }) }) @@ -63,7 +67,7 @@ var _ = Describe("Tracing", func() { t := NewConnectionTracer( &limitedWriter{WriteCloser: nopWriteCloser(buf), N: 250}, protocol.PerspectiveServer, - protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), ) for i := uint32(0); i < 1000; i++ { t.UpdatedPTOCount(i) @@ -85,7 +89,11 @@ var _ = Describe("Tracing", func() { BeforeEach(func() { buf = &bytes.Buffer{} t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) - tracer = t.TracerForConnection(context.Background(), logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) + tracer = t.TracerForConnection( + context.Background(), + logging.PerspectiveServer, + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + ) }) It("exports a trace that has the right metadata", func() { @@ -155,8 +163,8 @@ var _ = Describe("Tracing", func() { tracer.StartedConnection( &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 42}, &net.UDPAddr{IP: net.IPv4(192, 168, 12, 34), Port: 24}, - protocol.ConnectionID{1, 2, 3, 4}, - protocol.ConnectionID{5, 6, 7, 8}, + protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + protocol.ParseConnectionID([]byte{5, 6, 7, 8}), ) entry := exportAndParseSingle() Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) @@ -274,6 +282,7 @@ var _ = Describe("Tracing", func() { }) It("records sent transport parameters", func() { + rcid := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) tracer.SentTransportParameters(&logging.TransportParameters{ InitialMaxStreamDataBidiLocal: 1000, InitialMaxStreamDataBidiRemote: 2000, @@ -287,9 +296,9 @@ var _ = Describe("Tracing", func() { MaxUDPPayloadSize: 1234, MaxIdleTimeout: 321 * time.Millisecond, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + RetrySourceConnectionID: &rcid, ActiveConnectionIDLimit: 7, MaxDatagramFrameSize: protocol.InvalidByteCount, }) @@ -318,7 +327,7 @@ var _ = Describe("Tracing", func() { It("records the server's transport parameters, without a stateless reset token", func() { tracer.SentTransportParameters(&logging.TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), ActiveConnectionIDLimit: 7, }) entry := exportAndParseSingle() @@ -347,7 +356,7 @@ var _ = Describe("Tracing", func() { IPv4Port: 123, IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, IPv6Port: 456, - ConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + ConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), StatelessResetToken: protocol.StatelessResetToken{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, }, }) @@ -417,8 +426,8 @@ var _ = Describe("Tracing", func() { Header: logging.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), Length: 1337, Version: protocol.VersionTLS, }, @@ -454,7 +463,7 @@ var _ = Describe("Tracing", func() { It("records a sent packet, without an ACK", func() { tracer.SentPacket( &logging.ExtendedHeader{ - Header: logging.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}}, + Header: logging.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4})}, PacketNumber: 1337, }, 123, @@ -483,8 +492,8 @@ var _ = Describe("Tracing", func() { Header: logging.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), Token: []byte{0xde, 0xad, 0xbe, 0xef}, Length: 1234, Version: protocol.VersionTLS, @@ -522,8 +531,8 @@ var _ = Describe("Tracing", func() { &logging.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), Token: []byte{0xde, 0xad, 0xbe, 0xef}, Version: protocol.VersionTLS, }, diff --git a/qlog/trace.go b/qlog/trace.go index 4f0b5e64..cf61558a 100644 --- a/qlog/trace.go +++ b/qlog/trace.go @@ -3,6 +3,8 @@ package qlog import ( "time" + "github.com/lucas-clemente/quic-go/logging" + "github.com/francoispqt/gojay" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -38,8 +40,8 @@ func (p vantagePoint) MarshalJSONObject(enc *gojay.Encoder) { } type commonFields struct { - ODCID connectionID - GroupID connectionID + ODCID logging.ConnectionID + GroupID logging.ConnectionID ProtocolType string ReferenceTime time.Time } diff --git a/qlog/types.go b/qlog/types.go index b485e17d..42e562f9 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -39,12 +39,6 @@ func (s streamType) String() string { } } -type connectionID protocol.ConnectionID - -func (c connectionID) String() string { - return fmt.Sprintf("%x", []byte(c)) -} - // category is the qlog event category. type category uint8 diff --git a/server.go b/server.go index cacd0358..ae29ff9b 100644 --- a/server.go +++ b/server.go @@ -477,7 +477,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro if err != nil { return err } - s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID)) + s.logger.Debugf("Changing connection ID to %s.", connID) var conn quicConn tracingID := nextConnTracingID() if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { @@ -575,7 +575,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.Token = token if s.logger.Debug() { - s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID)) + s.logger.Debugf("Changing connection ID to %s.", srcConnID) s.logger.Debugf("-> Sending Retry") replyHdr.Log(s.logger) } diff --git a/server_test.go b/server_test.go index 672f17d3..1cc68a97 100644 --- a/server_test.go +++ b/server_test.go @@ -71,7 +71,7 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: destConnID, Version: protocol.VersionTLS, } @@ -82,11 +82,11 @@ var _ = Describe("Server", func() { } getInitialWithRandomDestConnID := func() *receivedPacket { - destConnID := make([]byte, 10) - _, err := rand.Read(destConnID) + b := make([]byte, 10) + _, err := rand.Read(b) Expect(err).ToNot(HaveOccurred()) - return getInitial(destConnID) + return getInitial(protocol.ParseConnectionID(b)) } parseHeader := func(data []byte) *wire.Header { @@ -204,7 +204,7 @@ var _ = Describe("Server", func() { p := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Version: serv.config.Versions[0], }, nil) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) @@ -217,7 +217,7 @@ var _ = Describe("Server", func() { p := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize-100), ) @@ -244,15 +244,15 @@ var _ = Describe("Server", func() { raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} retryToken, err := serv.tokenGenerator.NewRetryToken( raddr, - protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), + protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), ) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Version: protocol.VersionTLS, Token: retryToken, } @@ -263,7 +263,7 @@ var _ = Describe("Server", func() { rand.Read(token[:]) var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c @@ -272,7 +272,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})) conn := NewMockQuicConn(mockCtrl) serv.newConn = func( _ sendConn, @@ -294,8 +294,8 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicConn { Expect(enable0RTT).To(BeFalse()) - Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) - Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) + Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID @@ -325,8 +325,8 @@ var _ = Describe("Server", func() { }) It("sends a Version Negotiation Packet for unsupported versions", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -337,8 +337,8 @@ var _ = Describe("Server", func() { raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} packet.remoteAddr = raddr tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber) { - Expect(src).To(BeEquivalentTo(destConnID)) - Expect(dest).To(BeEquivalentTo(srcConnID)) + Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) + Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { @@ -346,8 +346,8 @@ var _ = Describe("Server", func() { Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) Expect(err).ToNot(HaveOccurred()) - Expect(dest).To(BeEquivalentTo(srcConnID)) - Expect(src).To(BeEquivalentTo(destConnID)) + Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) + Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) return len(b), nil }) @@ -357,8 +357,8 @@ var _ = Describe("Server", func() { It("doesn't send a Version Negotiation packets if sending them is disabled", func() { serv.config.DisableVersionNegotiationPackets = true - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -396,8 +396,8 @@ var _ = Describe("Server", func() { }) It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) p := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -423,8 +423,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Version: protocol.VersionTLS, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) @@ -455,8 +455,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Version: protocol.VersionTLS, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) @@ -465,7 +465,7 @@ var _ = Describe("Server", func() { rand.Read(token[:]) var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c @@ -474,7 +474,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) conn := NewMockQuicConn(mockCtrl) serv.newConn = func( @@ -566,7 +566,7 @@ var _ = Describe("Server", func() { return conn } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) serv.handlePacket(p) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) var wg sync.WaitGroup @@ -575,7 +575,7 @@ var _ = Describe("Server", func() { go func() { defer GinkgoRecover() defer wg.Done() - serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) + serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) }() } wg.Wait() @@ -614,8 +614,8 @@ var _ = Describe("Server", func() { return conn } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}), gomock.Any(), gomock.Any()).Return(false) Expect(serv.handlePacketImpl(p)).To(BeTrue()) Expect(createdConn).To(BeFalse()) }) @@ -688,7 +688,7 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) conn := NewMockQuicConn(mockCtrl) @@ -772,7 +772,7 @@ var _ = Describe("Server", func() { It("decodes the token from the token field", func() { raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) packet := getPacket(&wire.Header{ IsLongHeader: true, @@ -792,13 +792,13 @@ var _ = Describe("Server", func() { It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -830,14 +830,14 @@ var _ = Describe("Server", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.config.MaxRetryTokenAge = time.Millisecond raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) time.Sleep(2 * time.Millisecond) // make sure the token is expired hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -870,8 +870,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -902,8 +902,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -923,13 +923,13 @@ var _ = Describe("Server", func() { It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -1031,7 +1031,7 @@ var _ = Describe("Server", func() { tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) cancel() // complete the handshake @@ -1105,7 +1105,7 @@ var _ = Describe("Server", func() { }) serv.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) close(ready) @@ -1174,7 +1174,7 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) conn := NewMockQuicConn(mockCtrl)