keep increasing the packet number after version negotiation and retry

This commit is contained in:
Marten Seemann 2018-12-22 22:15:35 +06:30
parent 1abf9e1b37
commit 178ac0dacb
8 changed files with 59 additions and 13 deletions

View file

@ -37,6 +37,8 @@ type client struct {
destConnID protocol.ConnectionID destConnID protocol.ConnectionID
origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry) origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
initialPacketNumber protocol.PacketNumber
initialVersion protocol.VersionNumber initialVersion protocol.VersionNumber
version protocol.VersionNumber version protocol.VersionNumber
@ -340,7 +342,7 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
c.version = newVersion c.version = newVersion
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID) c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.closeForRecreating() c.initialPacketNumber = c.session.closeForRecreating()
} }
func (c *client) handleRetryPacket(hdr *wire.Header) { func (c *client) handleRetryPacket(hdr *wire.Header) {
@ -366,7 +368,7 @@ func (c *client) handleRetryPacket(hdr *wire.Header) {
c.origDestConnID = c.destConnID c.origDestConnID = c.destConnID
c.destConnID = hdr.SrcConnectionID c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token c.token = hdr.Token
c.session.closeForRecreating() c.initialPacketNumber = c.session.closeForRecreating()
} }
func (c *client) createNewTLSSession(version protocol.VersionNumber) error { func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
@ -397,6 +399,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
c.srcConnID, c.srcConnID,
c.config, c.config,
c.tlsConf, c.tlsConf,
c.initialPacketNumber,
params, params,
c.initialVersion, c.initialVersion,
c.logger, c.logger,

View file

@ -38,6 +38,7 @@ var _ = Describe("Client", func() {
srcConnID protocol.ConnectionID, srcConnID protocol.ConnectionID,
conf *Config, conf *Config,
tlsConf *tls.Config, tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
params *handshake.TransportParameters, params *handshake.TransportParameters,
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
logger utils.Logger, logger utils.Logger,
@ -142,6 +143,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -172,6 +174,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
tlsConf *tls.Config, tlsConf *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -202,6 +205,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -239,6 +243,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -279,6 +284,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -324,6 +330,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -369,6 +376,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -484,6 +492,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
configP *Config, configP *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
params *handshake.TransportParameters, params *handshake.TransportParameters,
_ protocol.VersionNumber, /* initial version */ _ protocol.VersionNumber, /* initial version */
_ utils.Logger, _ utils.Logger,
@ -530,8 +539,9 @@ var _ = Describe("Client", func() {
sess1.EXPECT().run().DoAndReturn(func() error { sess1.EXPECT().run().DoAndReturn(func() error {
return <-run1 return <-run1
}) })
sess1.EXPECT().closeForRecreating().Do(func() { sess1.EXPECT().closeForRecreating().DoAndReturn(func() protocol.PacketNumber {
run1 <- errCloseForRecreating run1 <- errCloseForRecreating
return 42
}) })
sess2 := NewMockQuicSession(mockCtrl) sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run() sess2.EXPECT().run()
@ -547,6 +557,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
initialPacketNumber protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -554,9 +565,11 @@ var _ = Describe("Client", func() {
) (quicSession, error) { ) (quicSession, error) {
switch len(sessions) { switch len(sessions) {
case 2: // for the first session case 2: // for the first session
Expect(initialPacketNumber).To(BeZero())
Expect(origDestConnID).To(BeNil()) Expect(origDestConnID).To(BeNil())
Expect(destConnID).ToNot(BeNil()) Expect(destConnID).ToNot(BeNil())
case 1: // for the second session case 1: // for the second session
Expect(initialPacketNumber).To(Equal(protocol.PacketNumber(42)))
Expect(origDestConnID).To(Equal(connID)) Expect(origDestConnID).To(Equal(connID))
Expect(destConnID).ToNot(Equal(connID)) Expect(destConnID).ToNot(Equal(connID))
} }
@ -615,6 +628,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,
@ -654,6 +668,7 @@ var _ = Describe("Client", func() {
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ *Config, _ *Config,
_ *tls.Config, _ *tls.Config,
_ protocol.PacketNumber,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ utils.Logger, _ utils.Logger,

View file

@ -78,7 +78,11 @@ type sentPacketHandler struct {
} }
// NewSentPacketHandler creates a new sentPacketHandler // NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler { func NewSentPacketHandler(
initialPacketNumber protocol.PacketNumber,
rttStats *congestion.RTTStats,
logger utils.Logger,
) SentPacketHandler {
congestion := congestion.NewCubicSender( congestion := congestion.NewCubicSender(
congestion.DefaultClock{}, congestion.DefaultClock{},
rttStats, rttStats,
@ -88,7 +92,7 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) Se
) )
return &sentPacketHandler{ return &sentPacketHandler{
packetNumberGenerator: newPacketNumberGenerator(0, protocol.SkipPacketAveragePeriodLength), packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
packetHistory: newSentPacketHistory(), packetHistory: newSentPacketHistory(),
rttStats: rttStats, rttStats: rttStats,
congestion: congestion, congestion: congestion,
@ -144,8 +148,10 @@ func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retra
} }
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ { func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ { if h.logger.Debug() && h.lastSentPacketNumber != 0 {
h.logger.Debugf("Skipping packet number %#x", p) for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %#x", p)
}
} }
h.lastSentPacketNumber = packet.PacketNumber h.lastSentPacketNumber = packet.PacketNumber

View file

@ -49,7 +49,7 @@ var _ = Describe("SentPacketHandler", func() {
BeforeEach(func() { BeforeEach(func() {
rttStats := &congestion.RTTStats{} rttStats := &congestion.RTTStats{}
handler = NewSentPacketHandler(rttStats, utils.DefaultLogger).(*sentPacketHandler) handler = NewSentPacketHandler(42, rttStats, utils.DefaultLogger).(*sentPacketHandler)
handler.SetHandshakeComplete() handler.SetHandshakeComplete()
streamFrame = wire.StreamFrame{ streamFrame = wire.StreamFrame{
StreamID: 5, StreamID: 5,
@ -962,4 +962,17 @@ var _ = Describe("SentPacketHandler", func() {
Expect(packet).To(BeNil()) Expect(packet).To(BeNil())
}) })
}) })
Context("peeking and popping packet number", func() {
It("peeks and pops the initial packet number", func() {
pn, _ := handler.PeekPacketNumber()
Expect(pn).To(Equal(protocol.PacketNumber(42)))
Expect(handler.PopPacketNumber()).To(Equal(protocol.PacketNumber(42)))
})
It("peeks and pops beyond the initial packet number", func() {
Expect(handler.PopPacketNumber()).To(Equal(protocol.PacketNumber(42)))
Expect(handler.PopPacketNumber()).To(BeNumerically(">", 42))
})
})
}) })

View file

@ -200,8 +200,10 @@ func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call {
} }
// closeForRecreating mocks base method // closeForRecreating mocks base method
func (m *MockQuicSession) closeForRecreating() { func (m *MockQuicSession) closeForRecreating() protocol.PacketNumber {
m.ctrl.Call(m, "closeForRecreating") ret := m.ctrl.Call(m, "closeForRecreating")
ret0, _ := ret[0].(protocol.PacketNumber)
return ret0
} }
// closeForRecreating indicates an expected call of closeForRecreating // closeForRecreating indicates an expected call of closeForRecreating

View file

@ -43,7 +43,7 @@ type quicSession interface {
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
run() error run() error
destroy(error) destroy(error)
closeForRecreating() closeForRecreating() protocol.PacketNumber
closeRemote(error) closeRemote(error)
} }

View file

@ -158,6 +158,7 @@ var newSession = func(
version: v, version: v,
} }
s.preSetup() s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.logger)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
s.streamsMap = newStreamsMap( s.streamsMap = newStreamsMap(
@ -218,6 +219,7 @@ var newClientSession = func(
srcConnID protocol.ConnectionID, srcConnID protocol.ConnectionID,
conf *Config, conf *Config,
tlsConf *tls.Config, tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
params *handshake.TransportParameters, params *handshake.TransportParameters,
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
logger utils.Logger, logger utils.Logger,
@ -235,6 +237,7 @@ var newClientSession = func(
version: v, version: v,
} }
s.preSetup() s.preSetup()
s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.logger)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
cs, clientHelloWritten, err := handshake.NewCryptoSetupClient( cs, clientHelloWritten, err := handshake.NewCryptoSetupClient(
@ -286,7 +289,6 @@ var newClientSession = func(
func (s *session) preSetup() { func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{} s.rttStats = &congestion.RTTStats{}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController( s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.InitialMaxData, protocol.InitialMaxData,
@ -720,8 +722,12 @@ func (s *session) destroy(e error) {
}) })
} }
func (s *session) closeForRecreating() { // closeForRecreating closes the session in order to recreate it immediately afterwards
// It returns the first packet number that should be used in the new session.
func (s *session) closeForRecreating() protocol.PacketNumber {
s.destroy(errCloseForRecreating) s.destroy(errCloseForRecreating)
nextPN, _ := s.sentPacketHandler.PeekPacketNumber()
return nextPN
} }
func (s *session) closeRemote(e error) { func (s *session) closeRemote(e error) {

View file

@ -1325,6 +1325,7 @@ var _ = Describe("Client Session", func() {
protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1},
populateClientConfig(&Config{}, true), populateClientConfig(&Config{}, true),
nil, // tls.Config nil, // tls.Config
42, // initial packet number
nil, // transport parameters nil, // transport parameters
protocol.VersionWhatever, protocol.VersionWhatever,
utils.DefaultLogger, utils.DefaultLogger,