mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
use the connection ID manager to save the destination connection ID
This commit is contained in:
parent
a321f9faa6
commit
772ffd3d20
7 changed files with 125 additions and 121 deletions
|
@ -11,28 +11,34 @@ import (
|
|||
type connIDManager struct {
|
||||
queue utils.NewConnectionIDList
|
||||
|
||||
activeSequenceNumber uint64
|
||||
activeConnectionID protocol.ConnectionID
|
||||
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
func newConnIDManager(queueControlFrame func(wire.Frame)) *connIDManager {
|
||||
return &connIDManager{queueControlFrame: queueControlFrame}
|
||||
func newConnIDManager(
|
||||
initialDestConnID protocol.ConnectionID,
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *connIDManager {
|
||||
h := &connIDManager{queueControlFrame: queueControlFrame}
|
||||
h.activeConnectionID = initialDestConnID
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
|
||||
if err := h.add(f); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.queue.Len() > protocol.MaxActiveConnectionIDs {
|
||||
// delete the first connection ID in the queue
|
||||
val := h.queue.Remove(h.queue.Front())
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: val.SequenceNumber,
|
||||
})
|
||||
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
|
||||
// Retire elements in the queue.
|
||||
// Doesn't retire the active connection ID.
|
||||
var next *utils.NewConnectionIDElement
|
||||
for el := h.queue.Front(); el != nil; el = next {
|
||||
if el.Value.SequenceNumber >= f.RetirePriorTo {
|
||||
|
@ -52,8 +58,7 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
|
|||
ConnectionID: f.ConnectionID,
|
||||
StatelessResetToken: &f.StatelessResetToken,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
// insert a new element somewhere in the middle
|
||||
for el := h.queue.Front(); el != nil; el = el.Next() {
|
||||
if el.Value.SequenceNumber == f.SequenceNumber {
|
||||
|
@ -63,7 +68,7 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
|
|||
if *el.Value.StatelessResetToken != f.StatelessResetToken {
|
||||
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", f.SequenceNumber)
|
||||
}
|
||||
return nil
|
||||
break
|
||||
}
|
||||
if el.Value.SequenceNumber > f.SequenceNumber {
|
||||
h.queue.InsertBefore(utils.NewConnectionID{
|
||||
|
@ -71,8 +76,37 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
|
|||
ConnectionID: f.ConnectionID,
|
||||
StatelessResetToken: &f.StatelessResetToken,
|
||||
}, el)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Retire the active connection ID, if necessary.
|
||||
if h.activeSequenceNumber < f.RetirePriorTo {
|
||||
// The queue is guaranteed to have at least one element at this point.
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
panic("should have processed NEW_CONNECTION_ID frame")
|
||||
}
|
||||
|
||||
func (h *connIDManager) updateConnectionID() {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: h.activeSequenceNumber,
|
||||
})
|
||||
front := h.queue.Remove(h.queue.Front())
|
||||
h.activeSequenceNumber = front.SequenceNumber
|
||||
h.activeConnectionID = front.ConnectionID
|
||||
}
|
||||
|
||||
// is called when the server performs a Retry
|
||||
// and when the server changes the connection ID in the first Initial sent
|
||||
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
|
||||
if h.activeSequenceNumber != 0 {
|
||||
panic("expected first connection ID to have sequence number 0")
|
||||
}
|
||||
h.activeConnectionID = newConnID
|
||||
}
|
||||
|
||||
func (h *connIDManager) Get() protocol.ConnectionID {
|
||||
return h.activeConnectionID
|
||||
}
|
||||
|
|
|
@ -12,10 +12,11 @@ var _ = Describe("Connection ID Manager", func() {
|
|||
m *connIDManager
|
||||
frameQueue []wire.Frame
|
||||
)
|
||||
initialConnID := protocol.ConnectionID{1, 1, 1, 1}
|
||||
|
||||
BeforeEach(func() {
|
||||
frameQueue = nil
|
||||
m = newConnIDManager(func(f wire.Frame) {
|
||||
m = newConnIDManager(initialConnID, func(f wire.Frame) {
|
||||
frameQueue = append(frameQueue, f)
|
||||
})
|
||||
})
|
||||
|
@ -28,10 +29,13 @@ var _ = Describe("Connection ID Manager", func() {
|
|||
return val.ConnectionID, val.StatelessResetToken
|
||||
}
|
||||
|
||||
It("returns nil if empty", func() {
|
||||
c, rt := get()
|
||||
Expect(c).To(BeNil())
|
||||
Expect(rt).To(BeNil())
|
||||
It("returns the initial connection ID", func() {
|
||||
Expect(m.Get()).To(Equal(initialConnID))
|
||||
})
|
||||
|
||||
It("changes the initial connection ID", func() {
|
||||
m.ChangeInitialConnID(protocol.ConnectionID{1, 2, 3, 4, 5})
|
||||
Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5}))
|
||||
})
|
||||
|
||||
It("adds and gets connection IDs", func() {
|
||||
|
@ -111,26 +115,28 @@ var _ = Describe("Connection ID Manager", func() {
|
|||
SequenceNumber: 17,
|
||||
ConnectionID: protocol.ConnectionID{3, 4, 5, 6},
|
||||
})).To(Succeed())
|
||||
Expect(frameQueue).To(HaveLen(2))
|
||||
Expect(frameQueue).To(HaveLen(3))
|
||||
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(10))
|
||||
Expect(frameQueue[1].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(13))
|
||||
c, _ := get()
|
||||
Expect(c).To(Equal(protocol.ConnectionID{3, 4, 5, 6}))
|
||||
Expect(frameQueue[2].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero())
|
||||
Expect(m.Get()).To(Equal(protocol.ConnectionID{3, 4, 5, 6}))
|
||||
})
|
||||
|
||||
It("retires old connection IDs when the peer sends too many new ones", func() {
|
||||
for i := uint8(0); i < protocol.MaxActiveConnectionIDs; i++ {
|
||||
for i := uint8(1); i <= protocol.MaxActiveConnectionIDs; i++ {
|
||||
Expect(m.Add(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: uint64(i),
|
||||
ConnectionID: protocol.ConnectionID{i, i, i, i},
|
||||
})).To(Succeed())
|
||||
}
|
||||
Expect(frameQueue).To(BeEmpty())
|
||||
Expect(frameQueue).To(HaveLen(1))
|
||||
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero())
|
||||
frameQueue = nil
|
||||
Expect(m.Add(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: protocol.MaxActiveConnectionIDs,
|
||||
SequenceNumber: protocol.MaxActiveConnectionIDs + 1,
|
||||
ConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
})).To(Succeed())
|
||||
Expect(frameQueue).To(HaveLen(1))
|
||||
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(0))
|
||||
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(1))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
wire "github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
|
@ -36,18 +35,6 @@ func (m *MockPacker) EXPECT() *MockPackerMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// ChangeDestConnectionID mocks base method
|
||||
func (m *MockPacker) ChangeDestConnectionID(arg0 protocol.ConnectionID) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ChangeDestConnectionID", arg0)
|
||||
}
|
||||
|
||||
// ChangeDestConnectionID indicates an expected call of ChangeDestConnectionID
|
||||
func (mr *MockPackerMockRecorder) ChangeDestConnectionID(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeDestConnectionID", reflect.TypeOf((*MockPacker)(nil).ChangeDestConnectionID), arg0)
|
||||
}
|
||||
|
||||
// HandleTransportParameters mocks base method
|
||||
func (m *MockPacker) HandleTransportParameters(arg0 *handshake.TransportParameters) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -21,7 +21,6 @@ type packer interface {
|
|||
|
||||
HandleTransportParameters(*handshake.TransportParameters)
|
||||
SetToken([]byte)
|
||||
ChangeDestConnectionID(protocol.ConnectionID)
|
||||
}
|
||||
|
||||
type sealer interface {
|
||||
|
@ -128,8 +127,8 @@ type ackFrameSource interface {
|
|||
}
|
||||
|
||||
type packetPacker struct {
|
||||
destConnID protocol.ConnectionID
|
||||
srcConnID protocol.ConnectionID
|
||||
getDestConnID func() protocol.ConnectionID
|
||||
|
||||
perspective protocol.Perspective
|
||||
version protocol.VersionNumber
|
||||
|
@ -155,8 +154,8 @@ type packetPacker struct {
|
|||
var _ packer = &packetPacker{}
|
||||
|
||||
func newPacketPacker(
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
getDestConnID func() protocol.ConnectionID,
|
||||
initialStream cryptoStream,
|
||||
handshakeStream cryptoStream,
|
||||
packetNumberManager packetNumberManager,
|
||||
|
@ -170,7 +169,7 @@ func newPacketPacker(
|
|||
) *packetPacker {
|
||||
return &packetPacker{
|
||||
cryptoSetup: cryptoSetup,
|
||||
destConnID: destConnID,
|
||||
getDestConnID: getDestConnID,
|
||||
srcConnID: srcConnID,
|
||||
initialStream: initialStream,
|
||||
handshakeStream: handshakeStream,
|
||||
|
@ -432,7 +431,7 @@ func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHea
|
|||
hdr := &wire.ExtendedHeader{}
|
||||
hdr.PacketNumber = pn
|
||||
hdr.PacketNumberLen = pnLen
|
||||
hdr.DestConnectionID = p.destConnID
|
||||
hdr.DestConnectionID = p.getDestConnID()
|
||||
hdr.KeyPhase = kp
|
||||
return hdr
|
||||
}
|
||||
|
@ -442,7 +441,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
|
|||
hdr := &wire.ExtendedHeader{}
|
||||
hdr.PacketNumber = pn
|
||||
hdr.PacketNumberLen = pnLen
|
||||
hdr.DestConnectionID = p.destConnID
|
||||
hdr.DestConnectionID = p.getDestConnID()
|
||||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
|
@ -550,10 +549,6 @@ func (p *packetPacker) writeAndSealPacketWithPadding(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
|
||||
p.destConnID = connID
|
||||
}
|
||||
|
||||
func (p *packetPacker) SetToken(token []byte) {
|
||||
p.token = token
|
||||
}
|
||||
|
|
|
@ -77,7 +77,7 @@ var _ = Describe("Packet packer", func() {
|
|||
|
||||
packer = newPacketPacker(
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
func() protocol.ConnectionID { return protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} },
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
pnManager,
|
||||
|
@ -126,28 +126,12 @@ var _ = Describe("Packet packer", func() {
|
|||
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
packer.srcConnID = srcConnID
|
||||
packer.destConnID = destConnID
|
||||
packer.getDestConnID = func() protocol.ConnectionID { return destConnID }
|
||||
h := packer.getLongHeader(protocol.EncryptionHandshake)
|
||||
Expect(h.SrcConnectionID).To(Equal(srcConnID))
|
||||
Expect(h.DestConnectionID).To(Equal(destConnID))
|
||||
})
|
||||
|
||||
It("changes the destination connection ID", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
|
||||
srcConnID := protocol.ConnectionID{1, 1, 1, 1, 1, 1, 1, 1}
|
||||
packer.srcConnID = srcConnID
|
||||
dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
packer.ChangeDestConnectionID(dest1)
|
||||
h := packer.getLongHeader(protocol.EncryptionInitial)
|
||||
Expect(h.SrcConnectionID).To(Equal(srcConnID))
|
||||
Expect(h.DestConnectionID).To(Equal(dest1))
|
||||
packer.ChangeDestConnectionID(dest2)
|
||||
h = packer.getLongHeader(protocol.EncryptionInitial)
|
||||
Expect(h.SrcConnectionID).To(Equal(srcConnID))
|
||||
Expect(h.DestConnectionID).To(Equal(dest2))
|
||||
})
|
||||
|
||||
It("gets a short header", func() {
|
||||
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4)
|
||||
h := packer.getShortHeader(protocol.KeyPhaseOne)
|
||||
|
@ -397,7 +381,7 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
// cut off the tag that the mock sealer added
|
||||
packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()]
|
||||
hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.destConnID))
|
||||
hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.getDestConnID()))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
r := bytes.NewReader(packet.raw)
|
||||
extHdr, err := hdr.ParseExtended(r, packer.version)
|
||||
|
|
42
session.go
42
session.go
|
@ -104,7 +104,6 @@ var errCloseForRecreating = errors.New("closing session in order to recreate it"
|
|||
type session struct {
|
||||
sessionRunner sessionRunner
|
||||
|
||||
destConnID protocol.ConnectionID
|
||||
origDestConnID protocol.ConnectionID // if the server sends a Retry, this is the connection ID we used initially
|
||||
srcConnID protocol.ConnectionID
|
||||
|
||||
|
@ -201,13 +200,13 @@ var newSession = func(
|
|||
sessionRunner: runner,
|
||||
config: conf,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
tokenGenerator: tokenGenerator,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
handshakeCompleteChan: make(chan struct{}),
|
||||
logger: logger,
|
||||
version: v,
|
||||
}
|
||||
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
|
||||
s.preSetup()
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
|
||||
initialStream := newCryptoStream()
|
||||
|
@ -231,9 +230,10 @@ var newSession = func(
|
|||
logger,
|
||||
)
|
||||
s.cryptoStreamHandler = cs
|
||||
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
|
||||
s.packer = newPacketPacker(
|
||||
s.destConnID,
|
||||
s.srcConnID,
|
||||
s.connIDManager.Get,
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
s.sentPacketHandler,
|
||||
|
@ -269,13 +269,13 @@ var newClientSession = func(
|
|||
sessionRunner: runner,
|
||||
config: conf,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
handshakeCompleteChan: make(chan struct{}),
|
||||
logger: logger,
|
||||
initialVersion: initialVersion,
|
||||
version: v,
|
||||
}
|
||||
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
|
||||
s.preSetup()
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
|
||||
initialStream := newCryptoStream()
|
||||
|
@ -285,7 +285,7 @@ var newClientSession = func(
|
|||
initialStream,
|
||||
handshakeStream,
|
||||
oneRTTStream,
|
||||
s.destConnID,
|
||||
destConnID,
|
||||
conn.RemoteAddr(),
|
||||
params,
|
||||
&handshakeRunner{
|
||||
|
@ -303,8 +303,8 @@ var newClientSession = func(
|
|||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
|
||||
s.unpacker = newPacketUnpacker(cs, s.version)
|
||||
s.packer = newPacketPacker(
|
||||
s.destConnID,
|
||||
s.srcConnID,
|
||||
s.connIDManager.Get,
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
s.sentPacketHandler,
|
||||
|
@ -333,7 +333,6 @@ func (s *session) preSetup() {
|
|||
s.sendQueue = newSendQueue(s.conn)
|
||||
s.retransmissionQueue = newRetransmissionQueue(s.version)
|
||||
s.frameParser = wire.NewFrameParser(s.version)
|
||||
s.connIDManager = newConnIDManager(s.queueControlFrame)
|
||||
s.rttStats = &congestion.RTTStats{}
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
|
@ -601,8 +600,9 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /
|
|||
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
// After this, all packets with a different source connection have to be ignored.
|
||||
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, s.destConnID)
|
||||
destConnID := s.connIDManager.Get()
|
||||
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(destConnID) {
|
||||
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, destConnID)
|
||||
return false
|
||||
}
|
||||
// drop 0-RTT packets
|
||||
|
@ -652,11 +652,12 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R
|
|||
return false
|
||||
}
|
||||
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
|
||||
if !hdr.OrigDestConnectionID.Equal(s.destConnID) {
|
||||
s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.destConnID)
|
||||
destConnID := s.connIDManager.Get()
|
||||
if !hdr.OrigDestConnectionID.Equal(destConnID) {
|
||||
s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, destConnID)
|
||||
return false
|
||||
}
|
||||
if hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
if hdr.SrcConnectionID.Equal(destConnID) {
|
||||
s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
|
||||
return false
|
||||
}
|
||||
|
@ -668,16 +669,16 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R
|
|||
}
|
||||
s.logger.Debugf("<- Received Retry")
|
||||
s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
|
||||
s.origDestConnID = s.destConnID
|
||||
s.destConnID = hdr.SrcConnectionID
|
||||
s.origDestConnID = destConnID
|
||||
newDestConnID := hdr.SrcConnectionID
|
||||
s.receivedRetry = true
|
||||
if err := s.sentPacketHandler.ResetForRetry(); err != nil {
|
||||
s.closeLocal(err)
|
||||
return false
|
||||
}
|
||||
s.cryptoStreamHandler.ChangeConnectionID(s.destConnID)
|
||||
s.cryptoStreamHandler.ChangeConnectionID(newDestConnID)
|
||||
s.packer.SetToken(hdr.Token)
|
||||
s.packer.ChangeDestConnectionID(s.destConnID)
|
||||
s.connIDManager.ChangeInitialConnID(newDestConnID)
|
||||
s.scheduleSending()
|
||||
return true
|
||||
}
|
||||
|
@ -688,10 +689,9 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
|
|||
}
|
||||
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.destConnID) {
|
||||
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.connIDManager.Get()) {
|
||||
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", packet.hdr.SrcConnectionID)
|
||||
s.destConnID = packet.hdr.SrcConnectionID
|
||||
s.packer.ChangeDestConnectionID(s.destConnID)
|
||||
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
|
||||
}
|
||||
|
||||
s.receivedFirstPacket = true
|
||||
|
@ -927,9 +927,9 @@ func (s *session) destroy(e error) {
|
|||
func (s *session) destroyImpl(e error) {
|
||||
s.closeOnce.Do(func() {
|
||||
if nerr, ok := e.(net.Error); ok && nerr.Timeout() {
|
||||
s.logger.Errorf("Destroying session %s: %s", s.destConnID, e)
|
||||
s.logger.Errorf("Destroying session %s: %s", s.connIDManager.Get(), e)
|
||||
} else {
|
||||
s.logger.Errorf("Destroying session %s with error: %s", s.destConnID, e)
|
||||
s.logger.Errorf("Destroying session %s with error: %s", s.connIDManager.Get(), e)
|
||||
}
|
||||
s.sessionRunner.Remove(s.srcConnID)
|
||||
s.closeChan <- closeError{err: e, sendClose: false, remote: false}
|
||||
|
|
|
@ -79,6 +79,7 @@ var _ = Describe("Session", func() {
|
|||
packer *MockPacker
|
||||
cryptoSetup *mocks.MockCryptoSetup
|
||||
)
|
||||
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
|
||||
getPacket := func(pn protocol.PacketNumber) *packedPacket {
|
||||
buffer := getPacketBuffer()
|
||||
|
@ -110,7 +111,7 @@ var _ = Describe("Session", func() {
|
|||
mconn,
|
||||
sessionRunner,
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
destConnID,
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
populateServerConfig(&Config{}),
|
||||
nil, // tls.Config
|
||||
|
@ -307,7 +308,7 @@ var _ = Describe("Session", func() {
|
|||
SequenceNumber: 10,
|
||||
ConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
}, 1, protocol.Encryption1RTT)).To(Succeed())
|
||||
Expect(sess.connIDManager.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4}))
|
||||
Expect(sess.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4}))
|
||||
})
|
||||
|
||||
It("handles PING frames", func() {
|
||||
|
@ -673,7 +674,7 @@ var _ = Describe("Session", func() {
|
|||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: sess.destConnID,
|
||||
DestConnectionID: destConnID,
|
||||
SrcConnectionID: sess.srcConnID,
|
||||
Length: 1,
|
||||
Version: sess.version,
|
||||
|
@ -685,7 +686,7 @@ var _ = Describe("Session", func() {
|
|||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: sess.destConnID,
|
||||
DestConnectionID: destConnID,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
Length: 1,
|
||||
Version: sess.version,
|
||||
|
@ -711,7 +712,7 @@ var _ = Describe("Session", func() {
|
|||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: sess.destConnID,
|
||||
DestConnectionID: destConnID,
|
||||
SrcConnectionID: sess.srcConnID,
|
||||
Length: 1,
|
||||
Version: sess.version,
|
||||
|
@ -752,7 +753,7 @@ var _ = Describe("Session", func() {
|
|||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: sess.destConnID,
|
||||
SrcConnectionID: destConnID,
|
||||
Version: protocol.VersionTLS,
|
||||
Length: length,
|
||||
},
|
||||
|
@ -1507,6 +1508,7 @@ var _ = Describe("Client Session", func() {
|
|||
tlsConf *tls.Config
|
||||
quicConf *Config
|
||||
)
|
||||
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
|
||||
getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket {
|
||||
buf := &bytes.Buffer{}
|
||||
|
@ -1539,7 +1541,7 @@ var _ = Describe("Client Session", func() {
|
|||
sess = newClientSession(
|
||||
mconn,
|
||||
sessionRunner,
|
||||
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
|
||||
destConnID,
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
quicConf,
|
||||
tlsConf,
|
||||
|
@ -1571,7 +1573,6 @@ var _ = Describe("Client Session", func() {
|
|||
sess.run()
|
||||
}()
|
||||
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
|
||||
packer.EXPECT().ChangeDestConnectionID(newConnID)
|
||||
Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
|
@ -1627,7 +1628,6 @@ var _ = Describe("Client Session", func() {
|
|||
It("handles Retry packets", func() {
|
||||
cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
|
||||
packer.EXPECT().SetToken([]byte("foobar"))
|
||||
packer.EXPECT().ChangeDestConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeTrue())
|
||||
})
|
||||
|
||||
|
@ -1637,7 +1637,7 @@ var _ = Describe("Client Session", func() {
|
|||
})
|
||||
|
||||
It("ignores Retry packets if the server didn't change the connection ID", func() {
|
||||
validRetryHdr.SrcConnectionID = sess.destConnID
|
||||
validRetryHdr.SrcConnectionID = destConnID
|
||||
Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse())
|
||||
})
|
||||
|
||||
|
@ -1724,7 +1724,7 @@ var _ = Describe("Client Session", func() {
|
|||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: sess.destConnID,
|
||||
DestConnectionID: destConnID,
|
||||
SrcConnectionID: sess.srcConnID,
|
||||
Length: 1,
|
||||
Version: sess.version,
|
||||
|
@ -1736,7 +1736,7 @@ var _ = Describe("Client Session", func() {
|
|||
Header: wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
DestConnectionID: sess.destConnID,
|
||||
DestConnectionID: destConnID,
|
||||
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
|
||||
Length: 1,
|
||||
Version: sess.version,
|
||||
|
@ -1746,7 +1746,6 @@ var _ = Describe("Client Session", func() {
|
|||
}
|
||||
Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
|
||||
// Send one packet, which might change the connection ID.
|
||||
packer.EXPECT().ChangeDestConnectionID(sess.srcConnID).MaxTimes(1)
|
||||
// only EXPECT one call to the unpacker
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionInitial,
|
||||
|
@ -1762,7 +1761,7 @@ var _ = Describe("Client Session", func() {
|
|||
// the connection to immediately break down
|
||||
It("fails on Initial-level ACK for unsent packet", func() {
|
||||
ackFrame := testutils.ComposeAckFrame(0, 0)
|
||||
initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{ackFrame})
|
||||
initialPacket := testutils.ComposeInitialPacket(destConnID, sess.srcConnID, sess.version, destConnID, []wire.Frame{ackFrame})
|
||||
Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse())
|
||||
})
|
||||
|
||||
|
@ -1771,7 +1770,7 @@ var _ = Describe("Client Session", func() {
|
|||
It("fails on Initial-level CONNECTION_CLOSE frame", func() {
|
||||
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
|
||||
connCloseFrame := testutils.ComposeConnCloseFrame()
|
||||
initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{connCloseFrame})
|
||||
initialPacket := testutils.ComposeInitialPacket(destConnID, sess.srcConnID, sess.version, destConnID, []wire.Frame{connCloseFrame})
|
||||
Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue())
|
||||
})
|
||||
|
||||
|
@ -1781,10 +1780,9 @@ var _ = Describe("Client Session", func() {
|
|||
newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
|
||||
cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID)
|
||||
packer.EXPECT().SetToken([]byte("foobar"))
|
||||
packer.EXPECT().ChangeDestConnectionID(newSrcConnID)
|
||||
|
||||
sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, sess.destConnID, sess.destConnID, []byte("foobar"), sess.version)))
|
||||
initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, nil)
|
||||
sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), sess.version)))
|
||||
initialPacket := testutils.ComposeInitialPacket(sess.connIDManager.Get(), sess.srcConnID, sess.version, sess.connIDManager.Get(), nil)
|
||||
Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse())
|
||||
})
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue