diff --git a/conn_id_generator.go b/conn_id_generator.go index dbebe8c8..04649b5d 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -22,6 +22,8 @@ type connIDGenerator struct { retireConnectionID func(protocol.ConnectionID) replaceWithClosed func(protocol.ConnectionID, packetHandler) queueControlFrame func(wire.Frame) + + version protocol.VersionNumber } func newConnIDGenerator( @@ -33,6 +35,7 @@ func newConnIDGenerator( retireConnectionID func(protocol.ConnectionID), replaceWithClosed func(protocol.ConnectionID, packetHandler), queueControlFrame func(wire.Frame), + version protocol.VersionNumber, ) *connIDGenerator { m := &connIDGenerator{ connIDLen: initialConnectionID.Len(), @@ -43,6 +46,7 @@ func newConnIDGenerator( retireConnectionID: retireConnectionID, replaceWithClosed: replaceWithClosed, queueControlFrame: queueControlFrame, + version: version, } m.activeSrcConnIDs[0] = initialConnectionID m.initialClientDestConnID = initialClientDestConnID @@ -76,7 +80,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect if !ok { return nil } - if connID.Equal(sentWithDestConnID) && !RetireBugBackwardsCompatibilityMode { + if connID.Equal(sentWithDestConnID) && !protocol.UseRetireBugBackwardsCompatibilityMode(RetireBugBackwardsCompatibilityMode, m.version) { return qerr.NewError(qerr.ProtocolViolation, fmt.Sprintf("tried to retire connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID)) } m.retireConnectionID(connID) @@ -89,7 +93,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect } func (m *connIDGenerator) issueNewConnID() error { - if RetireBugBackwardsCompatibilityMode { + if protocol.UseRetireBugBackwardsCompatibilityMode(RetireBugBackwardsCompatibilityMode, m.version) { return nil } connID, err := protocol.GenerateConnectionID(m.connIDLen) diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 5ff7e361..ef8016d9 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -41,6 +41,7 @@ var _ = Describe("Connection ID Generator", func() { func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, + protocol.VersionDraft29, ) }) diff --git a/internal/protocol/version.go b/internal/protocol/version.go index 8e233dec..797f0009 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -62,6 +62,12 @@ func (vn VersionNumber) toGQUICVersion() int { return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) } +// UseRetireBugBackwardsCompatibilityMode says if it is necessary to use the backwards compatilibity mode. +// This is only the case if it 1. is enabled and 2. draft-29 is used. +func UseRetireBugBackwardsCompatibilityMode(enabled bool, v VersionNumber) bool { + return enabled && v == VersionDraft29 +} + // IsSupportedVersion returns true if the server supports this version func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { for _, t := range supported { diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 04fadd80..dae3a2fa 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -49,6 +49,13 @@ var _ = Describe("Version", func() { } }) + It("says if backwards compatibility mode should be used", func() { + Expect(UseRetireBugBackwardsCompatibilityMode(true, VersionDraft29)).To(BeTrue()) + Expect(UseRetireBugBackwardsCompatibilityMode(true, VersionDraft32)).To(BeFalse()) + Expect(UseRetireBugBackwardsCompatibilityMode(false, VersionDraft29)).To(BeFalse()) + Expect(UseRetireBugBackwardsCompatibilityMode(false, VersionDraft32)).To(BeFalse()) + }) + Context("highest supported version", func() { It("finds the supported version", func() { supportedVersions := []VersionNumber{1, 2, 3} diff --git a/session.go b/session.go index 736b8d4e..3f8ff277 100644 --- a/session.go +++ b/session.go @@ -268,6 +268,7 @@ var newSession = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, + s.version, ) s.preSetup() s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( @@ -390,6 +391,7 @@ var newClientSession = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, + s.version, ) s.preSetup() s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(