add a function to pass the retry token to the packet packer

This commit is contained in:
Marten Seemann 2019-02-01 15:28:11 +09:00
parent 2712626e66
commit 7a7e7ca6eb
4 changed files with 20 additions and 8 deletions

View file

@ -108,3 +108,13 @@ func (m *MockPacker) PackRetransmission(arg0 *ackhandler.Packet) ([]*packedPacke
func (mr *MockPackerMockRecorder) PackRetransmission(arg0 interface{}) *gomock.Call { func (mr *MockPackerMockRecorder) PackRetransmission(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackRetransmission", reflect.TypeOf((*MockPacker)(nil).PackRetransmission), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackRetransmission", reflect.TypeOf((*MockPacker)(nil).PackRetransmission), arg0)
} }
// SetToken mocks base method
func (m *MockPacker) SetToken(arg0 []byte) {
m.ctrl.Call(m, "SetToken", arg0)
}
// SetToken indicates an expected call of SetToken
func (mr *MockPackerMockRecorder) SetToken(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0)
}

View file

@ -21,6 +21,7 @@ type packer interface {
PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
HandleTransportParameters(*handshake.TransportParameters) HandleTransportParameters(*handshake.TransportParameters)
SetToken([]byte)
ChangeDestConnectionID(protocol.ConnectionID) ChangeDestConnectionID(protocol.ConnectionID)
} }
@ -123,7 +124,6 @@ func newPacketPacker(
handshakeStream cryptoStream, handshakeStream cryptoStream,
packetNumberManager packetNumberManager, packetNumberManager packetNumberManager,
remoteAddr net.Addr, // only used for determining the max packet size remoteAddr net.Addr, // only used for determining the max packet size
token []byte,
cryptoSetup sealingManager, cryptoSetup sealingManager,
framer frameSource, framer frameSource,
acks ackFrameSource, acks ackFrameSource,
@ -132,7 +132,6 @@ func newPacketPacker(
) *packetPacker { ) *packetPacker {
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
token: token,
destConnID: destConnID, destConnID: destConnID,
srcConnID: srcConnID, srcConnID: srcConnID,
initialStream: initialStream, initialStream: initialStream,
@ -476,6 +475,10 @@ func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
p.destConnID = connID p.destConnID = connID
} }
func (p *packetPacker) SetToken(token []byte) {
p.token = token
}
func (p *packetPacker) HandleTransportParameters(params *handshake.TransportParameters) { func (p *packetPacker) HandleTransportParameters(params *handshake.TransportParameters) {
if params.MaxPacketSize != 0 { if params.MaxPacketSize != 0 {
p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxPacketSize) p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxPacketSize)

View file

@ -26,7 +26,6 @@ var _ = Describe("Packet packer", func() {
handshakeStream *MockCryptoStream handshakeStream *MockCryptoStream
sealingManager *MockSealingManager sealingManager *MockSealingManager
pnManager *mockackhandler.MockSentPacketHandler pnManager *mockackhandler.MockSentPacketHandler
token []byte
) )
checkLength := func(data []byte) { checkLength := func(data []byte) {
@ -57,7 +56,6 @@ var _ = Describe("Packet packer", func() {
BeforeEach(func() { BeforeEach(func() {
rand.Seed(GinkgoRandomSeed()) rand.Seed(GinkgoRandomSeed())
version := protocol.VersionTLS version := protocol.VersionTLS
token = []byte("initial token")
mockSender := NewMockStreamSender(mockCtrl) mockSender := NewMockStreamSender(mockCtrl)
mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes()
initialStream = NewMockCryptoStream(mockCtrl) initialStream = NewMockCryptoStream(mockCtrl)
@ -74,7 +72,6 @@ var _ = Describe("Packet packer", func() {
handshakeStream, handshakeStream,
pnManager, pnManager,
&net.TCPAddr{}, &net.TCPAddr{},
token, // token
sealingManager, sealingManager,
framer, framer,
ackFramer, ackFramer,
@ -772,6 +769,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("pads Initial packets to the required minimum packet size", func() { It("pads Initial packets to the required minimum packet size", func() {
token := []byte("initial token")
packer.SetToken(token)
f := &wire.CryptoFrame{Data: []byte("foobar")} f := &wire.CryptoFrame{Data: []byte("foobar")}
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
@ -856,7 +855,6 @@ var _ = Describe("Packet packer", func() {
packer.perspective = protocol.PerspectiveClient packer.perspective = protocol.PerspectiveClient
packet, err := packer.PackPacket() packet, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packet.header.Token).To(Equal(token))
Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize))
Expect(packet.frames).To(HaveLen(2)) Expect(packet.frames).To(HaveLen(2))
Expect(packet.frames[0]).To(Equal(ack)) Expect(packet.frames[0]).To(Equal(ack))
@ -884,6 +882,8 @@ var _ = Describe("Packet packer", func() {
}) })
It("packs a retransmission for an Initial packet", func() { It("packs a retransmission for an Initial packet", func() {
token := []byte("initial token")
packer.SetToken(token)
pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PeekPacketNumber().Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2)
pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42)) pnManager.EXPECT().PopPacketNumber().Return(protocol.PacketNumber(0x42))
sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil)

View file

@ -202,7 +202,6 @@ var newSession = func(
handshakeStream, handshakeStream,
s.sentPacketHandler, s.sentPacketHandler,
s.RemoteAddr(), s.RemoteAddr(),
nil, // no token
cs, cs,
s.framer, s.framer,
s.receivedPacketHandler, s.receivedPacketHandler,
@ -287,13 +286,13 @@ var newClientSession = func(
handshakeStream, handshakeStream,
s.sentPacketHandler, s.sentPacketHandler,
s.RemoteAddr(), s.RemoteAddr(),
token,
cs, cs,
s.framer, s.framer,
s.receivedPacketHandler, s.receivedPacketHandler,
s.perspective, s.perspective,
s.version, s.version,
) )
s.packer.SetToken(token)
return s, s.postSetup() return s, s.postSetup()
} }