diff --git a/conn_id_manager.go b/conn_id_manager.go index dcb36b96..f4bc5c0e 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -53,6 +53,10 @@ func newConnIDManager( } } +func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken *[16]byte) error { + return h.addConnectionID(1, connID, resetToken) +} + func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { if err := h.add(f); err != nil { return err @@ -64,7 +68,7 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { } func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { - // If the NEW_CONNECTION_ID frame is reordered, such that its sequenece number + // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number // was already retired, send the RETIRE_CONNECTION_ID frame immediately. if f.SequenceNumber < h.highestRetired { h.queueControlFrame(&wire.RetireConnectionIDFrame{ @@ -94,34 +98,8 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { return nil } - // insert a new element at the end - if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < f.SequenceNumber { - h.queue.PushBack(utils.NewConnectionID{ - SequenceNumber: f.SequenceNumber, - ConnectionID: f.ConnectionID, - StatelessResetToken: &f.StatelessResetToken, - }) - } else { - // insert a new element somewhere in the middle - for el := h.queue.Front(); el != nil; el = el.Next() { - if el.Value.SequenceNumber == f.SequenceNumber { - if !el.Value.ConnectionID.Equal(f.ConnectionID) { - return fmt.Errorf("received conflicting connection IDs for sequence number %d", f.SequenceNumber) - } - if *el.Value.StatelessResetToken != f.StatelessResetToken { - return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", f.SequenceNumber) - } - break - } - if el.Value.SequenceNumber > f.SequenceNumber { - h.queue.InsertBefore(utils.NewConnectionID{ - SequenceNumber: f.SequenceNumber, - ConnectionID: f.ConnectionID, - StatelessResetToken: &f.StatelessResetToken, - }, el) - break - } - } + if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, &f.StatelessResetToken); err != nil { + return err } // Retire the active connection ID, if necessary. @@ -132,6 +110,39 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { return nil } +func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken *[16]byte) error { + // insert a new element at the end + if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq { + h.queue.PushBack(utils.NewConnectionID{ + SequenceNumber: seq, + ConnectionID: connID, + StatelessResetToken: resetToken, + }) + return nil + } + // insert a new element somewhere in the middle + for el := h.queue.Front(); el != nil; el = el.Next() { + if el.Value.SequenceNumber == seq { + if !el.Value.ConnectionID.Equal(connID) { + return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) + } + if *el.Value.StatelessResetToken != *resetToken { + return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq) + } + break + } + if el.Value.SequenceNumber > seq { + h.queue.InsertBefore(utils.NewConnectionID{ + SequenceNumber: seq, + ConnectionID: connID, + StatelessResetToken: resetToken, + }, el) + break + } + } + return nil +} + func (h *connIDManager) updateConnectionID() { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: h.activeSequenceNumber, diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 43754c8a..49345d19 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -80,13 +80,18 @@ var _ = Describe("Connection ID Manager", func() { }) It("accepts duplicates", func() { - f := &wire.NewConnectionIDFrame{ + f1 := &wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, StatelessResetToken: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } - Expect(m.Add(f)).To(Succeed()) - Expect(m.Add(f)).To(Succeed()) + f2 := &wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, + } + Expect(m.Add(f1)).To(Succeed()) + Expect(m.Add(f2)).To(Succeed()) c1, rt1 := get() Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) Expect(*rt1).To(Equal([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) diff --git a/session.go b/session.go index c7e89eb1..9bf6fa65 100644 --- a/session.go +++ b/session.go @@ -1196,7 +1196,7 @@ func (s *session) processTransportParameters(params *handshake.TransportParamete if params.PreferredAddress != nil { s.logger.Debugf("Server sent preferred_address. Retiring the preferred_address connection ID.") // Retire the connection ID. - s.framer.QueueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: 1}) + s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, ¶ms.PreferredAddress.StatelessResetToken) } // On the server side, the early session is ready as soon as we processed // the client's transport parameters. diff --git a/session_test.go b/session_test.go index 5fe41ef3..0453339a 100644 --- a/session_test.go +++ b/session_test.go @@ -1955,20 +1955,26 @@ var _ = Describe("Client Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("immediately retires the preferred_address connection ID", func() { + It("uses the preferred_address connection ID", func() { params := &handshake.TransportParameters{ PreferredAddress: &handshake.PreferredAddress{ - IPv4: net.IPv4(127, 0, 0, 1), - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + IPv4: net.IPv4(127, 0, 0, 1), + IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + StatelessResetToken: [16]byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, }, } packer.EXPECT().HandleTransportParameters(gomock.Any()) packer.EXPECT().PackCoalescedPacket().MaxTimes(1) sess.processTransportParameters(params) + // make sure the connection ID is not retired cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) - Expect(cf).To(HaveLen(1)) - Expect(cf[0].Frame).To(Equal(&wire.RetireConnectionIDFrame{SequenceNumber: 1})) + Expect(cf).To(BeEmpty()) + sessionRunner.EXPECT().AddResetToken([16]byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, sess) + Expect(sess.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + // shut down + sessionRunner.EXPECT().RemoveResetToken([16]byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) + expectClose() }) It("uses the minimum of the peers' idle timeouts", func() {