diff --git a/conn_id_manager.go b/conn_id_manager.go index 1bed9928..7c2455c9 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -11,28 +11,34 @@ import ( type connIDManager struct { queue utils.NewConnectionIDList + activeSequenceNumber uint64 + activeConnectionID protocol.ConnectionID + queueControlFrame func(wire.Frame) } -func newConnIDManager(queueControlFrame func(wire.Frame)) *connIDManager { - return &connIDManager{queueControlFrame: queueControlFrame} +func newConnIDManager( + initialDestConnID protocol.ConnectionID, + queueControlFrame func(wire.Frame), +) *connIDManager { + h := &connIDManager{queueControlFrame: queueControlFrame} + h.activeConnectionID = initialDestConnID + return h } func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { if err := h.add(f); err != nil { return err } - if h.queue.Len() > protocol.MaxActiveConnectionIDs { - // delete the first connection ID in the queue - val := h.queue.Remove(h.queue.Front()) - h.queueControlFrame(&wire.RetireConnectionIDFrame{ - SequenceNumber: val.SequenceNumber, - }) + if h.queue.Len() >= protocol.MaxActiveConnectionIDs { + h.updateConnectionID() } return nil } func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { + // Retire elements in the queue. + // Doesn't retire the active connection ID. var next *utils.NewConnectionIDElement for el := h.queue.Front(); el != nil; el = next { if el.Value.SequenceNumber >= f.RetirePriorTo { @@ -52,27 +58,55 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { ConnectionID: f.ConnectionID, StatelessResetToken: &f.StatelessResetToken, }) - return nil - } - // insert a new element somewhere in the middle - for el := h.queue.Front(); el != nil; el = el.Next() { - if el.Value.SequenceNumber == f.SequenceNumber { - if !el.Value.ConnectionID.Equal(f.ConnectionID) { - return fmt.Errorf("received conflicting connection IDs for sequence number %d", f.SequenceNumber) + } else { + // insert a new element somewhere in the middle + for el := h.queue.Front(); el != nil; el = el.Next() { + if el.Value.SequenceNumber == f.SequenceNumber { + if !el.Value.ConnectionID.Equal(f.ConnectionID) { + return fmt.Errorf("received conflicting connection IDs for sequence number %d", f.SequenceNumber) + } + if *el.Value.StatelessResetToken != f.StatelessResetToken { + return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", f.SequenceNumber) + } + break } - if *el.Value.StatelessResetToken != f.StatelessResetToken { - return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", f.SequenceNumber) + if el.Value.SequenceNumber > f.SequenceNumber { + h.queue.InsertBefore(utils.NewConnectionID{ + SequenceNumber: f.SequenceNumber, + ConnectionID: f.ConnectionID, + StatelessResetToken: &f.StatelessResetToken, + }, el) + break } - return nil - } - if el.Value.SequenceNumber > f.SequenceNumber { - h.queue.InsertBefore(utils.NewConnectionID{ - SequenceNumber: f.SequenceNumber, - ConnectionID: f.ConnectionID, - StatelessResetToken: &f.StatelessResetToken, - }, el) - return nil } } - panic("should have processed NEW_CONNECTION_ID frame") + + // Retire the active connection ID, if necessary. + if h.activeSequenceNumber < f.RetirePriorTo { + // The queue is guaranteed to have at least one element at this point. + h.updateConnectionID() + } + return nil +} + +func (h *connIDManager) updateConnectionID() { + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: h.activeSequenceNumber, + }) + front := h.queue.Remove(h.queue.Front()) + h.activeSequenceNumber = front.SequenceNumber + h.activeConnectionID = front.ConnectionID +} + +// is called when the server performs a Retry +// and when the server changes the connection ID in the first Initial sent +func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { + if h.activeSequenceNumber != 0 { + panic("expected first connection ID to have sequence number 0") + } + h.activeConnectionID = newConnID +} + +func (h *connIDManager) Get() protocol.ConnectionID { + return h.activeConnectionID } diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 6419ef1a..fa9fd212 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -12,10 +12,11 @@ var _ = Describe("Connection ID Manager", func() { m *connIDManager frameQueue []wire.Frame ) + initialConnID := protocol.ConnectionID{1, 1, 1, 1} BeforeEach(func() { frameQueue = nil - m = newConnIDManager(func(f wire.Frame) { + m = newConnIDManager(initialConnID, func(f wire.Frame) { frameQueue = append(frameQueue, f) }) }) @@ -28,10 +29,13 @@ var _ = Describe("Connection ID Manager", func() { return val.ConnectionID, val.StatelessResetToken } - It("returns nil if empty", func() { - c, rt := get() - Expect(c).To(BeNil()) - Expect(rt).To(BeNil()) + It("returns the initial connection ID", func() { + Expect(m.Get()).To(Equal(initialConnID)) + }) + + It("changes the initial connection ID", func() { + m.ChangeInitialConnID(protocol.ConnectionID{1, 2, 3, 4, 5}) + Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) }) It("adds and gets connection IDs", func() { @@ -111,26 +115,28 @@ var _ = Describe("Connection ID Manager", func() { SequenceNumber: 17, ConnectionID: protocol.ConnectionID{3, 4, 5, 6}, })).To(Succeed()) - Expect(frameQueue).To(HaveLen(2)) + Expect(frameQueue).To(HaveLen(3)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(10)) Expect(frameQueue[1].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(13)) - c, _ := get() - Expect(c).To(Equal(protocol.ConnectionID{3, 4, 5, 6})) + Expect(frameQueue[2].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) + Expect(m.Get()).To(Equal(protocol.ConnectionID{3, 4, 5, 6})) }) It("retires old connection IDs when the peer sends too many new ones", func() { - for i := uint8(0); i < protocol.MaxActiveConnectionIDs; i++ { + for i := uint8(1); i <= protocol.MaxActiveConnectionIDs; i++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(i), ConnectionID: protocol.ConnectionID{i, i, i, i}, })).To(Succeed()) } - Expect(frameQueue).To(BeEmpty()) + Expect(frameQueue).To(HaveLen(1)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) + frameQueue = nil Expect(m.Add(&wire.NewConnectionIDFrame{ - SequenceNumber: protocol.MaxActiveConnectionIDs, + SequenceNumber: protocol.MaxActiveConnectionIDs + 1, ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) - Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(0)) + Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(1)) }) }) diff --git a/mock_packer_test.go b/mock_packer_test.go index b96cfb68..e2bf3507 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -9,7 +9,6 @@ import ( gomock "github.com/golang/mock/gomock" handshake "github.com/lucas-clemente/quic-go/internal/handshake" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -36,18 +35,6 @@ func (m *MockPacker) EXPECT() *MockPackerMockRecorder { return m.recorder } -// ChangeDestConnectionID mocks base method -func (m *MockPacker) ChangeDestConnectionID(arg0 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ChangeDestConnectionID", arg0) -} - -// ChangeDestConnectionID indicates an expected call of ChangeDestConnectionID -func (mr *MockPackerMockRecorder) ChangeDestConnectionID(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeDestConnectionID", reflect.TypeOf((*MockPacker)(nil).ChangeDestConnectionID), arg0) -} - // HandleTransportParameters mocks base method func (m *MockPacker) HandleTransportParameters(arg0 *handshake.TransportParameters) { m.ctrl.T.Helper() diff --git a/packet_packer.go b/packet_packer.go index 9b1b3803..357e77dc 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -21,7 +21,6 @@ type packer interface { HandleTransportParameters(*handshake.TransportParameters) SetToken([]byte) - ChangeDestConnectionID(protocol.ConnectionID) } type sealer interface { @@ -128,8 +127,8 @@ type ackFrameSource interface { } type packetPacker struct { - destConnID protocol.ConnectionID - srcConnID protocol.ConnectionID + srcConnID protocol.ConnectionID + getDestConnID func() protocol.ConnectionID perspective protocol.Perspective version protocol.VersionNumber @@ -155,8 +154,8 @@ type packetPacker struct { var _ packer = &packetPacker{} func newPacketPacker( - destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + getDestConnID func() protocol.ConnectionID, initialStream cryptoStream, handshakeStream cryptoStream, packetNumberManager packetNumberManager, @@ -170,7 +169,7 @@ func newPacketPacker( ) *packetPacker { return &packetPacker{ cryptoSetup: cryptoSetup, - destConnID: destConnID, + getDestConnID: getDestConnID, srcConnID: srcConnID, initialStream: initialStream, handshakeStream: handshakeStream, @@ -432,7 +431,7 @@ func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHea hdr := &wire.ExtendedHeader{} hdr.PacketNumber = pn hdr.PacketNumberLen = pnLen - hdr.DestConnectionID = p.destConnID + hdr.DestConnectionID = p.getDestConnID() hdr.KeyPhase = kp return hdr } @@ -442,7 +441,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex hdr := &wire.ExtendedHeader{} hdr.PacketNumber = pn hdr.PacketNumberLen = pnLen - hdr.DestConnectionID = p.destConnID + hdr.DestConnectionID = p.getDestConnID() switch encLevel { case protocol.EncryptionInitial: @@ -550,10 +549,6 @@ func (p *packetPacker) writeAndSealPacketWithPadding( }, nil } -func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) { - p.destConnID = connID -} - func (p *packetPacker) SetToken(token []byte) { p.token = token } diff --git a/packet_packer_test.go b/packet_packer_test.go index c1b9a58c..d8a95002 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -77,7 +77,7 @@ var _ = Describe("Packet packer", func() { packer = newPacketPacker( protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + func() protocol.ConnectionID { return protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} }, initialStream, handshakeStream, pnManager, @@ -126,28 +126,12 @@ var _ = Describe("Packet packer", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} packer.srcConnID = srcConnID - packer.destConnID = destConnID + packer.getDestConnID = func() protocol.ConnectionID { return destConnID } h := packer.getLongHeader(protocol.EncryptionHandshake) Expect(h.SrcConnectionID).To(Equal(srcConnID)) Expect(h.DestConnectionID).To(Equal(destConnID)) }) - It("changes the destination connection ID", func() { - pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) - srcConnID := protocol.ConnectionID{1, 1, 1, 1, 1, 1, 1, 1} - packer.srcConnID = srcConnID - dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - packer.ChangeDestConnectionID(dest1) - h := packer.getLongHeader(protocol.EncryptionInitial) - Expect(h.SrcConnectionID).To(Equal(srcConnID)) - Expect(h.DestConnectionID).To(Equal(dest1)) - packer.ChangeDestConnectionID(dest2) - h = packer.getLongHeader(protocol.EncryptionInitial) - Expect(h.SrcConnectionID).To(Equal(srcConnID)) - Expect(h.DestConnectionID).To(Equal(dest2)) - }) - It("gets a short header", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4) h := packer.getShortHeader(protocol.KeyPhaseOne) @@ -397,7 +381,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()] - hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.destConnID)) + hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.getDestConnID())) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.raw) extHdr, err := hdr.ParseExtended(r, packer.version) diff --git a/session.go b/session.go index 09af228d..5f8b45e9 100644 --- a/session.go +++ b/session.go @@ -104,7 +104,6 @@ var errCloseForRecreating = errors.New("closing session in order to recreate it" type session struct { sessionRunner sessionRunner - destConnID protocol.ConnectionID origDestConnID protocol.ConnectionID // if the server sends a Retry, this is the connection ID we used initially srcConnID protocol.ConnectionID @@ -201,13 +200,13 @@ var newSession = func( sessionRunner: runner, config: conf, srcConnID: srcConnID, - destConnID: destConnID, tokenGenerator: tokenGenerator, perspective: protocol.PerspectiveServer, handshakeCompleteChan: make(chan struct{}), logger: logger, version: v, } + s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame) s.preSetup() s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger) initialStream := newCryptoStream() @@ -231,9 +230,10 @@ var newSession = func( logger, ) s.cryptoStreamHandler = cs + s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame) s.packer = newPacketPacker( - s.destConnID, s.srcConnID, + s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, @@ -269,13 +269,13 @@ var newClientSession = func( sessionRunner: runner, config: conf, srcConnID: srcConnID, - destConnID: destConnID, perspective: protocol.PerspectiveClient, handshakeCompleteChan: make(chan struct{}), logger: logger, initialVersion: initialVersion, version: v, } + s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame) s.preSetup() s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger) initialStream := newCryptoStream() @@ -285,7 +285,7 @@ var newClientSession = func( initialStream, handshakeStream, oneRTTStream, - s.destConnID, + destConnID, conn.RemoteAddr(), params, &handshakeRunner{ @@ -303,8 +303,8 @@ var newClientSession = func( s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) s.unpacker = newPacketUnpacker(cs, s.version) s.packer = newPacketPacker( - s.destConnID, s.srcConnID, + s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, @@ -333,7 +333,6 @@ func (s *session) preSetup() { s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue(s.version) s.frameParser = wire.NewFrameParser(s.version) - s.connIDManager = newConnIDManager(s.queueControlFrame) s.rttStats = &congestion.RTTStats{} s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( @@ -601,8 +600,9 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { - s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, s.destConnID) + destConnID := s.connIDManager.Get() + if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(destConnID) { + s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, destConnID) return false } // drop 0-RTT packets @@ -652,11 +652,12 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R return false } (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - if !hdr.OrigDestConnectionID.Equal(s.destConnID) { - s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.destConnID) + destConnID := s.connIDManager.Get() + if !hdr.OrigDestConnectionID.Equal(destConnID) { + s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, destConnID) return false } - if hdr.SrcConnectionID.Equal(s.destConnID) { + if hdr.SrcConnectionID.Equal(destConnID) { s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") return false } @@ -668,16 +669,16 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R } s.logger.Debugf("<- Received Retry") s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) - s.origDestConnID = s.destConnID - s.destConnID = hdr.SrcConnectionID + s.origDestConnID = destConnID + newDestConnID := hdr.SrcConnectionID s.receivedRetry = true if err := s.sentPacketHandler.ResetForRetry(); err != nil { s.closeLocal(err) return false } - s.cryptoStreamHandler.ChangeConnectionID(s.destConnID) + s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) s.packer.SetToken(hdr.Token) - s.packer.ChangeDestConnectionID(s.destConnID) + s.connIDManager.ChangeInitialConnID(newDestConnID) s.scheduleSending() return true } @@ -688,10 +689,9 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time } // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.destConnID) { + if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.connIDManager.Get()) { s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", packet.hdr.SrcConnectionID) - s.destConnID = packet.hdr.SrcConnectionID - s.packer.ChangeDestConnectionID(s.destConnID) + s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } s.receivedFirstPacket = true @@ -927,9 +927,9 @@ func (s *session) destroy(e error) { func (s *session) destroyImpl(e error) { s.closeOnce.Do(func() { if nerr, ok := e.(net.Error); ok && nerr.Timeout() { - s.logger.Errorf("Destroying session %s: %s", s.destConnID, e) + s.logger.Errorf("Destroying session %s: %s", s.connIDManager.Get(), e) } else { - s.logger.Errorf("Destroying session %s with error: %s", s.destConnID, e) + s.logger.Errorf("Destroying session %s with error: %s", s.connIDManager.Get(), e) } s.sessionRunner.Remove(s.srcConnID) s.closeChan <- closeError{err: e, sendClose: false, remote: false} diff --git a/session_test.go b/session_test.go index 2d216a3d..7d49e636 100644 --- a/session_test.go +++ b/session_test.go @@ -79,6 +79,7 @@ var _ = Describe("Session", func() { packer *MockPacker cryptoSetup *mocks.MockCryptoSetup ) + destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} getPacket := func(pn protocol.PacketNumber) *packedPacket { buffer := getPacketBuffer() @@ -110,7 +111,7 @@ var _ = Describe("Session", func() { mconn, sessionRunner, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + destConnID, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, populateServerConfig(&Config{}), nil, // tls.Config @@ -307,7 +308,7 @@ var _ = Describe("Session", func() { SequenceNumber: 10, ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, }, 1, protocol.Encryption1RTT)).To(Succeed()) - Expect(sess.connIDManager.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(sess.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) }) It("handles PING frames", func() { @@ -673,7 +674,7 @@ var _ = Describe("Session", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: sess.destConnID, + DestConnectionID: destConnID, SrcConnectionID: sess.srcConnID, Length: 1, Version: sess.version, @@ -685,7 +686,7 @@ var _ = Describe("Session", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: sess.destConnID, + DestConnectionID: destConnID, SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, Length: 1, Version: sess.version, @@ -711,7 +712,7 @@ var _ = Describe("Session", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: sess.destConnID, + DestConnectionID: destConnID, SrcConnectionID: sess.srcConnID, Length: 1, Version: sess.version, @@ -752,7 +753,7 @@ var _ = Describe("Session", func() { IsLongHeader: true, Type: protocol.PacketTypeHandshake, DestConnectionID: connID, - SrcConnectionID: sess.destConnID, + SrcConnectionID: destConnID, Version: protocol.VersionTLS, Length: length, }, @@ -1507,6 +1508,7 @@ var _ = Describe("Client Session", func() { tlsConf *tls.Config quicConf *Config ) + destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} @@ -1539,7 +1541,7 @@ var _ = Describe("Client Session", func() { sess = newClientSession( mconn, sessionRunner, - protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + destConnID, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, quicConf, tlsConf, @@ -1571,7 +1573,6 @@ var _ = Describe("Client Session", func() { sess.run() }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} - packer.EXPECT().ChangeDestConnectionID(newConnID) Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, @@ -1627,7 +1628,6 @@ var _ = Describe("Client Session", func() { It("handles Retry packets", func() { cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) packer.EXPECT().SetToken([]byte("foobar")) - packer.EXPECT().ChangeDestConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeTrue()) }) @@ -1637,7 +1637,7 @@ var _ = Describe("Client Session", func() { }) It("ignores Retry packets if the server didn't change the connection ID", func() { - validRetryHdr.SrcConnectionID = sess.destConnID + validRetryHdr.SrcConnectionID = destConnID Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse()) }) @@ -1724,7 +1724,7 @@ var _ = Describe("Client Session", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: sess.destConnID, + DestConnectionID: destConnID, SrcConnectionID: sess.srcConnID, Length: 1, Version: sess.version, @@ -1736,7 +1736,7 @@ var _ = Describe("Client Session", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: sess.destConnID, + DestConnectionID: destConnID, SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, Length: 1, Version: sess.version, @@ -1746,7 +1746,6 @@ var _ = Describe("Client Session", func() { } Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) // Send one packet, which might change the connection ID. - packer.EXPECT().ChangeDestConnectionID(sess.srcConnID).MaxTimes(1) // only EXPECT one call to the unpacker unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.EncryptionInitial, @@ -1762,7 +1761,7 @@ var _ = Describe("Client Session", func() { // the connection to immediately break down It("fails on Initial-level ACK for unsent packet", func() { ackFrame := testutils.ComposeAckFrame(0, 0) - initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{ackFrame}) + initialPacket := testutils.ComposeInitialPacket(destConnID, sess.srcConnID, sess.version, destConnID, []wire.Frame{ackFrame}) Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) @@ -1771,7 +1770,7 @@ var _ = Describe("Client Session", func() { It("fails on Initial-level CONNECTION_CLOSE frame", func() { sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()) connCloseFrame := testutils.ComposeConnCloseFrame() - initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{connCloseFrame}) + initialPacket := testutils.ComposeInitialPacket(destConnID, sess.srcConnID, sess.version, destConnID, []wire.Frame{connCloseFrame}) Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) }) @@ -1781,10 +1780,9 @@ var _ = Describe("Client Session", func() { newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) packer.EXPECT().SetToken([]byte("foobar")) - packer.EXPECT().ChangeDestConnectionID(newSrcConnID) - sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, sess.destConnID, sess.destConnID, []byte("foobar"), sess.version))) - initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, nil) + sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), sess.version))) + initialPacket := testutils.ComposeInitialPacket(sess.connIDManager.Get(), sess.srcConnID, sess.version, sess.connIDManager.Get(), nil) Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) })