use the connection ID manager to save the destination connection ID

This commit is contained in:
Marten Seemann 2019-10-25 17:58:53 +07:00
parent a321f9faa6
commit 772ffd3d20
7 changed files with 125 additions and 121 deletions

View file

@ -11,28 +11,34 @@ import (
type connIDManager struct {
queue utils.NewConnectionIDList
activeSequenceNumber uint64
activeConnectionID protocol.ConnectionID
queueControlFrame func(wire.Frame)
}
func newConnIDManager(queueControlFrame func(wire.Frame)) *connIDManager {
return &connIDManager{queueControlFrame: queueControlFrame}
func newConnIDManager(
initialDestConnID protocol.ConnectionID,
queueControlFrame func(wire.Frame),
) *connIDManager {
h := &connIDManager{queueControlFrame: queueControlFrame}
h.activeConnectionID = initialDestConnID
return h
}
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
if err := h.add(f); err != nil {
return err
}
if h.queue.Len() > protocol.MaxActiveConnectionIDs {
// delete the first connection ID in the queue
val := h.queue.Remove(h.queue.Front())
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: val.SequenceNumber,
})
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
h.updateConnectionID()
}
return nil
}
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
// Retire elements in the queue.
// Doesn't retire the active connection ID.
var next *utils.NewConnectionIDElement
for el := h.queue.Front(); el != nil; el = next {
if el.Value.SequenceNumber >= f.RetirePriorTo {
@ -52,27 +58,55 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
ConnectionID: f.ConnectionID,
StatelessResetToken: &f.StatelessResetToken,
})
return nil
}
// insert a new element somewhere in the middle
for el := h.queue.Front(); el != nil; el = el.Next() {
if el.Value.SequenceNumber == f.SequenceNumber {
if !el.Value.ConnectionID.Equal(f.ConnectionID) {
return fmt.Errorf("received conflicting connection IDs for sequence number %d", f.SequenceNumber)
} else {
// insert a new element somewhere in the middle
for el := h.queue.Front(); el != nil; el = el.Next() {
if el.Value.SequenceNumber == f.SequenceNumber {
if !el.Value.ConnectionID.Equal(f.ConnectionID) {
return fmt.Errorf("received conflicting connection IDs for sequence number %d", f.SequenceNumber)
}
if *el.Value.StatelessResetToken != f.StatelessResetToken {
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", f.SequenceNumber)
}
break
}
if *el.Value.StatelessResetToken != f.StatelessResetToken {
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", f.SequenceNumber)
if el.Value.SequenceNumber > f.SequenceNumber {
h.queue.InsertBefore(utils.NewConnectionID{
SequenceNumber: f.SequenceNumber,
ConnectionID: f.ConnectionID,
StatelessResetToken: &f.StatelessResetToken,
}, el)
break
}
return nil
}
if el.Value.SequenceNumber > f.SequenceNumber {
h.queue.InsertBefore(utils.NewConnectionID{
SequenceNumber: f.SequenceNumber,
ConnectionID: f.ConnectionID,
StatelessResetToken: &f.StatelessResetToken,
}, el)
return nil
}
}
panic("should have processed NEW_CONNECTION_ID frame")
// Retire the active connection ID, if necessary.
if h.activeSequenceNumber < f.RetirePriorTo {
// The queue is guaranteed to have at least one element at this point.
h.updateConnectionID()
}
return nil
}
func (h *connIDManager) updateConnectionID() {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: h.activeSequenceNumber,
})
front := h.queue.Remove(h.queue.Front())
h.activeSequenceNumber = front.SequenceNumber
h.activeConnectionID = front.ConnectionID
}
// is called when the server performs a Retry
// and when the server changes the connection ID in the first Initial sent
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
if h.activeSequenceNumber != 0 {
panic("expected first connection ID to have sequence number 0")
}
h.activeConnectionID = newConnID
}
func (h *connIDManager) Get() protocol.ConnectionID {
return h.activeConnectionID
}

View file

@ -12,10 +12,11 @@ var _ = Describe("Connection ID Manager", func() {
m *connIDManager
frameQueue []wire.Frame
)
initialConnID := protocol.ConnectionID{1, 1, 1, 1}
BeforeEach(func() {
frameQueue = nil
m = newConnIDManager(func(f wire.Frame) {
m = newConnIDManager(initialConnID, func(f wire.Frame) {
frameQueue = append(frameQueue, f)
})
})
@ -28,10 +29,13 @@ var _ = Describe("Connection ID Manager", func() {
return val.ConnectionID, val.StatelessResetToken
}
It("returns nil if empty", func() {
c, rt := get()
Expect(c).To(BeNil())
Expect(rt).To(BeNil())
It("returns the initial connection ID", func() {
Expect(m.Get()).To(Equal(initialConnID))
})
It("changes the initial connection ID", func() {
m.ChangeInitialConnID(protocol.ConnectionID{1, 2, 3, 4, 5})
Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5}))
})
It("adds and gets connection IDs", func() {
@ -111,26 +115,28 @@ var _ = Describe("Connection ID Manager", func() {
SequenceNumber: 17,
ConnectionID: protocol.ConnectionID{3, 4, 5, 6},
})).To(Succeed())
Expect(frameQueue).To(HaveLen(2))
Expect(frameQueue).To(HaveLen(3))
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(10))
Expect(frameQueue[1].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(13))
c, _ := get()
Expect(c).To(Equal(protocol.ConnectionID{3, 4, 5, 6}))
Expect(frameQueue[2].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero())
Expect(m.Get()).To(Equal(protocol.ConnectionID{3, 4, 5, 6}))
})
It("retires old connection IDs when the peer sends too many new ones", func() {
for i := uint8(0); i < protocol.MaxActiveConnectionIDs; i++ {
for i := uint8(1); i <= protocol.MaxActiveConnectionIDs; i++ {
Expect(m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: uint64(i),
ConnectionID: protocol.ConnectionID{i, i, i, i},
})).To(Succeed())
}
Expect(frameQueue).To(BeEmpty())
Expect(frameQueue).To(HaveLen(1))
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero())
frameQueue = nil
Expect(m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: protocol.MaxActiveConnectionIDs,
SequenceNumber: protocol.MaxActiveConnectionIDs + 1,
ConnectionID: protocol.ConnectionID{1, 2, 3, 4},
})).To(Succeed())
Expect(frameQueue).To(HaveLen(1))
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(0))
Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(1))
})
})

View file

@ -9,7 +9,6 @@ import (
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
@ -36,18 +35,6 @@ func (m *MockPacker) EXPECT() *MockPackerMockRecorder {
return m.recorder
}
// ChangeDestConnectionID mocks base method
func (m *MockPacker) ChangeDestConnectionID(arg0 protocol.ConnectionID) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ChangeDestConnectionID", arg0)
}
// ChangeDestConnectionID indicates an expected call of ChangeDestConnectionID
func (mr *MockPackerMockRecorder) ChangeDestConnectionID(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeDestConnectionID", reflect.TypeOf((*MockPacker)(nil).ChangeDestConnectionID), arg0)
}
// HandleTransportParameters mocks base method
func (m *MockPacker) HandleTransportParameters(arg0 *handshake.TransportParameters) {
m.ctrl.T.Helper()

View file

@ -21,7 +21,6 @@ type packer interface {
HandleTransportParameters(*handshake.TransportParameters)
SetToken([]byte)
ChangeDestConnectionID(protocol.ConnectionID)
}
type sealer interface {
@ -128,8 +127,8 @@ type ackFrameSource interface {
}
type packetPacker struct {
destConnID protocol.ConnectionID
srcConnID protocol.ConnectionID
srcConnID protocol.ConnectionID
getDestConnID func() protocol.ConnectionID
perspective protocol.Perspective
version protocol.VersionNumber
@ -155,8 +154,8 @@ type packetPacker struct {
var _ packer = &packetPacker{}
func newPacketPacker(
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
getDestConnID func() protocol.ConnectionID,
initialStream cryptoStream,
handshakeStream cryptoStream,
packetNumberManager packetNumberManager,
@ -170,7 +169,7 @@ func newPacketPacker(
) *packetPacker {
return &packetPacker{
cryptoSetup: cryptoSetup,
destConnID: destConnID,
getDestConnID: getDestConnID,
srcConnID: srcConnID,
initialStream: initialStream,
handshakeStream: handshakeStream,
@ -432,7 +431,7 @@ func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHea
hdr := &wire.ExtendedHeader{}
hdr.PacketNumber = pn
hdr.PacketNumberLen = pnLen
hdr.DestConnectionID = p.destConnID
hdr.DestConnectionID = p.getDestConnID()
hdr.KeyPhase = kp
return hdr
}
@ -442,7 +441,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
hdr := &wire.ExtendedHeader{}
hdr.PacketNumber = pn
hdr.PacketNumberLen = pnLen
hdr.DestConnectionID = p.destConnID
hdr.DestConnectionID = p.getDestConnID()
switch encLevel {
case protocol.EncryptionInitial:
@ -550,10 +549,6 @@ func (p *packetPacker) writeAndSealPacketWithPadding(
}, nil
}
func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
p.destConnID = connID
}
func (p *packetPacker) SetToken(token []byte) {
p.token = token
}

View file

@ -77,7 +77,7 @@ var _ = Describe("Packet packer", func() {
packer = newPacketPacker(
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
func() protocol.ConnectionID { return protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} },
initialStream,
handshakeStream,
pnManager,
@ -126,28 +126,12 @@ var _ = Describe("Packet packer", func() {
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packer.srcConnID = srcConnID
packer.destConnID = destConnID
packer.getDestConnID = func() protocol.ConnectionID { return destConnID }
h := packer.getLongHeader(protocol.EncryptionHandshake)
Expect(h.SrcConnectionID).To(Equal(srcConnID))
Expect(h.DestConnectionID).To(Equal(destConnID))
})
It("changes the destination connection ID", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2)
srcConnID := protocol.ConnectionID{1, 1, 1, 1, 1, 1, 1, 1}
packer.srcConnID = srcConnID
dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
packer.ChangeDestConnectionID(dest1)
h := packer.getLongHeader(protocol.EncryptionInitial)
Expect(h.SrcConnectionID).To(Equal(srcConnID))
Expect(h.DestConnectionID).To(Equal(dest1))
packer.ChangeDestConnectionID(dest2)
h = packer.getLongHeader(protocol.EncryptionInitial)
Expect(h.SrcConnectionID).To(Equal(srcConnID))
Expect(h.DestConnectionID).To(Equal(dest2))
})
It("gets a short header", func() {
pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4)
h := packer.getShortHeader(protocol.KeyPhaseOne)
@ -397,7 +381,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).ToNot(HaveOccurred())
// cut off the tag that the mock sealer added
packet.raw = packet.raw[:len(packet.raw)-sealer.Overhead()]
hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.destConnID))
hdr, _, _, err := wire.ParsePacket(packet.raw, len(packer.getDestConnID()))
Expect(err).ToNot(HaveOccurred())
r := bytes.NewReader(packet.raw)
extHdr, err := hdr.ParseExtended(r, packer.version)

View file

@ -104,7 +104,6 @@ var errCloseForRecreating = errors.New("closing session in order to recreate it"
type session struct {
sessionRunner sessionRunner
destConnID protocol.ConnectionID
origDestConnID protocol.ConnectionID // if the server sends a Retry, this is the connection ID we used initially
srcConnID protocol.ConnectionID
@ -201,13 +200,13 @@ var newSession = func(
sessionRunner: runner,
config: conf,
srcConnID: srcConnID,
destConnID: destConnID,
tokenGenerator: tokenGenerator,
perspective: protocol.PerspectiveServer,
handshakeCompleteChan: make(chan struct{}),
logger: logger,
version: v,
}
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream()
@ -231,9 +230,10 @@ var newSession = func(
logger,
)
s.cryptoStreamHandler = cs
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
s.packer = newPacketPacker(
s.destConnID,
s.srcConnID,
s.connIDManager.Get,
initialStream,
handshakeStream,
s.sentPacketHandler,
@ -269,13 +269,13 @@ var newClientSession = func(
sessionRunner: runner,
config: conf,
srcConnID: srcConnID,
destConnID: destConnID,
perspective: protocol.PerspectiveClient,
handshakeCompleteChan: make(chan struct{}),
logger: logger,
initialVersion: initialVersion,
version: v,
}
s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame)
s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger)
initialStream := newCryptoStream()
@ -285,7 +285,7 @@ var newClientSession = func(
initialStream,
handshakeStream,
oneRTTStream,
s.destConnID,
destConnID,
conn.RemoteAddr(),
params,
&handshakeRunner{
@ -303,8 +303,8 @@ var newClientSession = func(
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.version)
s.packer = newPacketPacker(
s.destConnID,
s.srcConnID,
s.connIDManager.Get,
initialStream,
handshakeStream,
s.sentPacketHandler,
@ -333,7 +333,6 @@ func (s *session) preSetup() {
s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue(s.version)
s.frameParser = wire.NewFrameParser(s.version)
s.connIDManager = newConnIDManager(s.queueControlFrame)
s.rttStats = &congestion.RTTStats{}
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
@ -601,8 +600,9 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, s.destConnID)
destConnID := s.connIDManager.Get()
if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(destConnID) {
s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, destConnID)
return false
}
// drop 0-RTT packets
@ -652,11 +652,12 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R
return false
}
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
if !hdr.OrigDestConnectionID.Equal(s.destConnID) {
s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.destConnID)
destConnID := s.connIDManager.Get()
if !hdr.OrigDestConnectionID.Equal(destConnID) {
s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, destConnID)
return false
}
if hdr.SrcConnectionID.Equal(s.destConnID) {
if hdr.SrcConnectionID.Equal(destConnID) {
s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
return false
}
@ -668,16 +669,16 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R
}
s.logger.Debugf("<- Received Retry")
s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
s.origDestConnID = s.destConnID
s.destConnID = hdr.SrcConnectionID
s.origDestConnID = destConnID
newDestConnID := hdr.SrcConnectionID
s.receivedRetry = true
if err := s.sentPacketHandler.ResetForRetry(); err != nil {
s.closeLocal(err)
return false
}
s.cryptoStreamHandler.ChangeConnectionID(s.destConnID)
s.cryptoStreamHandler.ChangeConnectionID(newDestConnID)
s.packer.SetToken(hdr.Token)
s.packer.ChangeDestConnectionID(s.destConnID)
s.connIDManager.ChangeInitialConnID(newDestConnID)
s.scheduleSending()
return true
}
@ -688,10 +689,9 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
}
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.destConnID) {
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.connIDManager.Get()) {
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", packet.hdr.SrcConnectionID)
s.destConnID = packet.hdr.SrcConnectionID
s.packer.ChangeDestConnectionID(s.destConnID)
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
}
s.receivedFirstPacket = true
@ -927,9 +927,9 @@ func (s *session) destroy(e error) {
func (s *session) destroyImpl(e error) {
s.closeOnce.Do(func() {
if nerr, ok := e.(net.Error); ok && nerr.Timeout() {
s.logger.Errorf("Destroying session %s: %s", s.destConnID, e)
s.logger.Errorf("Destroying session %s: %s", s.connIDManager.Get(), e)
} else {
s.logger.Errorf("Destroying session %s with error: %s", s.destConnID, e)
s.logger.Errorf("Destroying session %s with error: %s", s.connIDManager.Get(), e)
}
s.sessionRunner.Remove(s.srcConnID)
s.closeChan <- closeError{err: e, sendClose: false, remote: false}

View file

@ -79,6 +79,7 @@ var _ = Describe("Session", func() {
packer *MockPacker
cryptoSetup *mocks.MockCryptoSetup
)
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
getPacket := func(pn protocol.PacketNumber) *packedPacket {
buffer := getPacketBuffer()
@ -110,7 +111,7 @@ var _ = Describe("Session", func() {
mconn,
sessionRunner,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
destConnID,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
populateServerConfig(&Config{}),
nil, // tls.Config
@ -307,7 +308,7 @@ var _ = Describe("Session", func() {
SequenceNumber: 10,
ConnectionID: protocol.ConnectionID{1, 2, 3, 4},
}, 1, protocol.Encryption1RTT)).To(Succeed())
Expect(sess.connIDManager.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4}))
Expect(sess.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4}))
})
It("handles PING frames", func() {
@ -673,7 +674,7 @@ var _ = Describe("Session", func() {
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: sess.destConnID,
DestConnectionID: destConnID,
SrcConnectionID: sess.srcConnID,
Length: 1,
Version: sess.version,
@ -685,7 +686,7 @@ var _ = Describe("Session", func() {
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: sess.destConnID,
DestConnectionID: destConnID,
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
Length: 1,
Version: sess.version,
@ -711,7 +712,7 @@ var _ = Describe("Session", func() {
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: sess.destConnID,
DestConnectionID: destConnID,
SrcConnectionID: sess.srcConnID,
Length: 1,
Version: sess.version,
@ -752,7 +753,7 @@ var _ = Describe("Session", func() {
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
SrcConnectionID: sess.destConnID,
SrcConnectionID: destConnID,
Version: protocol.VersionTLS,
Length: length,
},
@ -1507,6 +1508,7 @@ var _ = Describe("Client Session", func() {
tlsConf *tls.Config
quicConf *Config
)
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{}
@ -1539,7 +1541,7 @@ var _ = Describe("Client Session", func() {
sess = newClientSession(
mconn,
sessionRunner,
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
destConnID,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
quicConf,
tlsConf,
@ -1571,7 +1573,6 @@ var _ = Describe("Client Session", func() {
sess.run()
}()
newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}
packer.EXPECT().ChangeDestConnectionID(newConnID)
Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
@ -1627,7 +1628,6 @@ var _ = Describe("Client Session", func() {
It("handles Retry packets", func() {
cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
packer.EXPECT().SetToken([]byte("foobar"))
packer.EXPECT().ChangeDestConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})
Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeTrue())
})
@ -1637,7 +1637,7 @@ var _ = Describe("Client Session", func() {
})
It("ignores Retry packets if the server didn't change the connection ID", func() {
validRetryHdr.SrcConnectionID = sess.destConnID
validRetryHdr.SrcConnectionID = destConnID
Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse())
})
@ -1724,7 +1724,7 @@ var _ = Describe("Client Session", func() {
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: sess.destConnID,
DestConnectionID: destConnID,
SrcConnectionID: sess.srcConnID,
Length: 1,
Version: sess.version,
@ -1736,7 +1736,7 @@ var _ = Describe("Client Session", func() {
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: sess.destConnID,
DestConnectionID: destConnID,
SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef},
Length: 1,
Version: sess.version,
@ -1746,7 +1746,6 @@ var _ = Describe("Client Session", func() {
}
Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID))
// Send one packet, which might change the connection ID.
packer.EXPECT().ChangeDestConnectionID(sess.srcConnID).MaxTimes(1)
// only EXPECT one call to the unpacker
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
@ -1762,7 +1761,7 @@ var _ = Describe("Client Session", func() {
// the connection to immediately break down
It("fails on Initial-level ACK for unsent packet", func() {
ackFrame := testutils.ComposeAckFrame(0, 0)
initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{ackFrame})
initialPacket := testutils.ComposeInitialPacket(destConnID, sess.srcConnID, sess.version, destConnID, []wire.Frame{ackFrame})
Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse())
})
@ -1771,7 +1770,7 @@ var _ = Describe("Client Session", func() {
It("fails on Initial-level CONNECTION_CLOSE frame", func() {
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
connCloseFrame := testutils.ComposeConnCloseFrame()
initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, []wire.Frame{connCloseFrame})
initialPacket := testutils.ComposeInitialPacket(destConnID, sess.srcConnID, sess.version, destConnID, []wire.Frame{connCloseFrame})
Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue())
})
@ -1781,10 +1780,9 @@ var _ = Describe("Client Session", func() {
newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID)
packer.EXPECT().SetToken([]byte("foobar"))
packer.EXPECT().ChangeDestConnectionID(newSrcConnID)
sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, sess.destConnID, sess.destConnID, []byte("foobar"), sess.version)))
initialPacket := testutils.ComposeInitialPacket(sess.destConnID, sess.srcConnID, sess.version, sess.destConnID, nil)
sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), sess.version)))
initialPacket := testutils.ComposeInitialPacket(sess.connIDManager.Get(), sess.srcConnID, sess.version, sess.connIDManager.Get(), nil)
Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse())
})