always use connection IDs longer than 8 bytes when sending a Retry

A server is allowed to perform multiple Retries. There's little to gain
from doing so, but it's something our API allows. If a server performs
multiple Retries, it must use a connection ID that's at least 8 bytes
long. Only if it doesn't perform any further Retries it is allowed to
use shorter IDs. Therefore, we're on the safe side by always using a
long connection ID.
This shouldn't have a performance impact, since the server changes the
connection ID to a short value with the first Handshake packet it sends.
This commit is contained in:
Marten Seemann 2018-08-14 17:34:21 +07:00
parent 829edc04ab
commit 872e1747f4
6 changed files with 17 additions and 13 deletions

View file

@ -57,10 +57,10 @@ var _ packetHandler = &client{}
var ( var (
// make it possible to mock connection ID generation in the tests // make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID generateConnectionID = protocol.GenerateConnectionID
generateDestConnectionID = protocol.GenerateDestinationConnectionID generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry") errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
) )
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
@ -259,7 +259,7 @@ func (c *client) generateConnectionIDs() error {
} }
destConnID := srcConnID destConnID := srcConnID
if c.version.UsesTLS() { if c.version.UsesTLS() {
destConnID, err = generateDestConnectionID() destConnID, err = generateConnectionIDForInitial()
if err != nil { if err != nil {
return err return err
} }

View file

@ -94,22 +94,22 @@ var _ = Describe("Client", func() {
Context("Dialing", func() { Context("Dialing", func() {
var origGenerateConnectionID func(int) (protocol.ConnectionID, error) var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
var origGenerateDestConnectionID func() (protocol.ConnectionID, error) var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)
BeforeEach(func() { BeforeEach(func() {
origGenerateConnectionID = generateConnectionID origGenerateConnectionID = generateConnectionID
origGenerateDestConnectionID = generateDestConnectionID origGenerateConnectionIDForInitial = generateConnectionIDForInitial
generateConnectionID = func(int) (protocol.ConnectionID, error) { generateConnectionID = func(int) (protocol.ConnectionID, error) {
return connID, nil return connID, nil
} }
generateDestConnectionID = func() (protocol.ConnectionID, error) { generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
return connID, nil return connID, nil
} }
}) })
AfterEach(func() { AfterEach(func() {
generateConnectionID = origGenerateConnectionID generateConnectionID = origGenerateConnectionID
generateDestConnectionID = origGenerateDestConnectionID generateConnectionIDForInitial = origGenerateConnectionIDForInitial
}) })
It("resolves the address", func() { It("resolves the address", func() {

View file

@ -21,9 +21,9 @@ func GenerateConnectionID(len int) (ConnectionID, error) {
return ConnectionID(b), nil return ConnectionID(b), nil
} }
// GenerateDestinationConnectionID generates a connection ID for the Initial packet. // GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
// It uses a length randomly chosen between 8 and 18 bytes. // It uses a length randomly chosen between 8 and 18 bytes.
func GenerateDestinationConnectionID() (ConnectionID, error) { func GenerateConnectionIDForInitial() (ConnectionID, error) {
r := make([]byte, 1) r := make([]byte, 1)
if _, err := rand.Read(r); err != nil { if _, err := rand.Read(r); err != nil {
return nil, err return nil, err

View file

@ -27,7 +27,7 @@ var _ = Describe("Connection ID generation", func() {
It("generates random length destination connection IDs", func() { It("generates random length destination connection IDs", func() {
var has8ByteConnID, has18ByteConnID bool var has8ByteConnID, has18ByteConnID bool
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
c, err := GenerateDestinationConnectionID() c, err := GenerateConnectionIDForInitial()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(BeNumerically(">=", 8)) Expect(c.Len()).To(BeNumerically(">=", 8))
Expect(c.Len()).To(BeNumerically("<=", 18)) Expect(c.Len()).To(BeNumerically("<=", 18))

View file

@ -128,6 +128,9 @@ func (s *serverTLS) handleInitialImpl(p *receivedPacket) (quicSession, protocol.
mconf := s.mintConf.Clone() mconf := s.mintConf.Clone()
mconf.ExtensionHandler = extHandler mconf.ExtensionHandler = extHandler
// A server is allowed to perform multiple Retries.
// It doesn't make much sense, but it's something that our API allows.
// In that case it must use a source connection ID of at least 8 bytes.
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -159,7 +162,7 @@ func (s *serverTLS) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
if err != nil { if err != nil {
return err return err
} }
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) connID, err := protocol.GenerateConnectionIDForInitial()
if err != nil { if err != nil {
return err return err
} }

View file

@ -95,6 +95,7 @@ var _ = Describe("Stateless TLS handling", func() {
replyHdr := parseHeader(conn.dataWritten.Bytes()) replyHdr := parseHeader(conn.dataWritten.Bytes())
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.SrcConnectionID.Len()).To(BeNumerically(">=", protocol.MinConnectionIDLenInitial))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID)) Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.Token).ToNot(BeEmpty()) Expect(replyHdr.Token).ToNot(BeEmpty())