diff --git a/client.go b/client.go index 5a5cc914..88fc6df2 100644 --- a/client.go +++ b/client.go @@ -57,10 +57,10 @@ var _ packetHandler = &client{} var ( // make it possible to mock connection ID generation in the tests - generateConnectionID = protocol.GenerateConnectionID - generateDestConnectionID = protocol.GenerateDestinationConnectionID - errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") - errCloseSessionForRetry = errors.New("closing session in response to a stateless retry") + generateConnectionID = protocol.GenerateConnectionID + generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial + errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version") + errCloseSessionForRetry = errors.New("closing session in response to a stateless retry") ) // DialAddr establishes a new QUIC connection to a server. @@ -259,7 +259,7 @@ func (c *client) generateConnectionIDs() error { } destConnID := srcConnID if c.version.UsesTLS() { - destConnID, err = generateDestConnectionID() + destConnID, err = generateConnectionIDForInitial() if err != nil { return err } diff --git a/client_test.go b/client_test.go index b40aa27f..01f3d54e 100644 --- a/client_test.go +++ b/client_test.go @@ -94,22 +94,22 @@ var _ = Describe("Client", func() { Context("Dialing", func() { var origGenerateConnectionID func(int) (protocol.ConnectionID, error) - var origGenerateDestConnectionID func() (protocol.ConnectionID, error) + var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error) BeforeEach(func() { origGenerateConnectionID = generateConnectionID - origGenerateDestConnectionID = generateDestConnectionID + origGenerateConnectionIDForInitial = generateConnectionIDForInitial generateConnectionID = func(int) (protocol.ConnectionID, error) { return connID, nil } - generateDestConnectionID = func() (protocol.ConnectionID, error) { + generateConnectionIDForInitial = func() (protocol.ConnectionID, error) { return connID, nil } }) AfterEach(func() { generateConnectionID = origGenerateConnectionID - generateDestConnectionID = origGenerateDestConnectionID + generateConnectionIDForInitial = origGenerateConnectionIDForInitial }) It("resolves the address", func() { diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index beacbfcf..f99461b2 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -21,9 +21,9 @@ func GenerateConnectionID(len int) (ConnectionID, error) { return ConnectionID(b), nil } -// GenerateDestinationConnectionID generates a connection ID for the Initial packet. +// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. // It uses a length randomly chosen between 8 and 18 bytes. -func GenerateDestinationConnectionID() (ConnectionID, error) { +func GenerateConnectionIDForInitial() (ConnectionID, error) { r := make([]byte, 1) if _, err := rand.Read(r); err != nil { return nil, err diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go index 3d0d90e2..f0c7f7cc 100644 --- a/internal/protocol/connection_id_test.go +++ b/internal/protocol/connection_id_test.go @@ -27,7 +27,7 @@ var _ = Describe("Connection ID generation", func() { It("generates random length destination connection IDs", func() { var has8ByteConnID, has18ByteConnID bool for i := 0; i < 1000; i++ { - c, err := GenerateDestinationConnectionID() + c, err := GenerateConnectionIDForInitial() Expect(err).ToNot(HaveOccurred()) Expect(c.Len()).To(BeNumerically(">=", 8)) Expect(c.Len()).To(BeNumerically("<=", 18)) diff --git a/server_tls.go b/server_tls.go index b894e48b..c27eca58 100644 --- a/server_tls.go +++ b/server_tls.go @@ -128,6 +128,9 @@ func (s *serverTLS) handleInitialImpl(p *receivedPacket) (quicSession, protocol. mconf := s.mintConf.Clone() mconf.ExtensionHandler = extHandler + // A server is allowed to perform multiple Retries. + // It doesn't make much sense, but it's something that our API allows. + // In that case it must use a source connection ID of at least 8 bytes. connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return nil, nil, err @@ -159,7 +162,7 @@ func (s *serverTLS) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error { if err != nil { return err } - connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + connID, err := protocol.GenerateConnectionIDForInitial() if err != nil { return err } diff --git a/server_tls_test.go b/server_tls_test.go index e9ef4665..8d1290a6 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -95,6 +95,7 @@ var _ = Describe("Stateless TLS handling", func() { replyHdr := parseHeader(conn.dataWritten.Bytes()) Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) + Expect(replyHdr.SrcConnectionID.Len()).To(BeNumerically(">=", protocol.MinConnectionIDLenInitial)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID)) Expect(replyHdr.Token).ToNot(BeEmpty())