mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 13:47:35 +03:00
always use connection IDs longer than 8 bytes when sending a Retry
A server is allowed to perform multiple Retries. There's little to gain from doing so, but it's something our API allows. If a server performs multiple Retries, it must use a connection ID that's at least 8 bytes long. Only if it doesn't perform any further Retries it is allowed to use shorter IDs. Therefore, we're on the safe side by always using a long connection ID. This shouldn't have a performance impact, since the server changes the connection ID to a short value with the first Handshake packet it sends.
This commit is contained in:
parent
829edc04ab
commit
872e1747f4
6 changed files with 17 additions and 13 deletions
10
client.go
10
client.go
|
@ -57,10 +57,10 @@ var _ packetHandler = &client{}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// make it possible to mock connection ID generation in the tests
|
// make it possible to mock connection ID generation in the tests
|
||||||
generateConnectionID = protocol.GenerateConnectionID
|
generateConnectionID = protocol.GenerateConnectionID
|
||||||
generateDestConnectionID = protocol.GenerateDestinationConnectionID
|
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||||
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
|
||||||
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
|
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
|
||||||
)
|
)
|
||||||
|
|
||||||
// DialAddr establishes a new QUIC connection to a server.
|
// DialAddr establishes a new QUIC connection to a server.
|
||||||
|
@ -259,7 +259,7 @@ func (c *client) generateConnectionIDs() error {
|
||||||
}
|
}
|
||||||
destConnID := srcConnID
|
destConnID := srcConnID
|
||||||
if c.version.UsesTLS() {
|
if c.version.UsesTLS() {
|
||||||
destConnID, err = generateDestConnectionID()
|
destConnID, err = generateConnectionIDForInitial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,22 +94,22 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
Context("Dialing", func() {
|
Context("Dialing", func() {
|
||||||
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
|
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
|
||||||
var origGenerateDestConnectionID func() (protocol.ConnectionID, error)
|
var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
origGenerateConnectionID = generateConnectionID
|
origGenerateConnectionID = generateConnectionID
|
||||||
origGenerateDestConnectionID = generateDestConnectionID
|
origGenerateConnectionIDForInitial = generateConnectionIDForInitial
|
||||||
generateConnectionID = func(int) (protocol.ConnectionID, error) {
|
generateConnectionID = func(int) (protocol.ConnectionID, error) {
|
||||||
return connID, nil
|
return connID, nil
|
||||||
}
|
}
|
||||||
generateDestConnectionID = func() (protocol.ConnectionID, error) {
|
generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
|
||||||
return connID, nil
|
return connID, nil
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
generateConnectionID = origGenerateConnectionID
|
generateConnectionID = origGenerateConnectionID
|
||||||
generateDestConnectionID = origGenerateDestConnectionID
|
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
|
||||||
})
|
})
|
||||||
|
|
||||||
It("resolves the address", func() {
|
It("resolves the address", func() {
|
||||||
|
|
|
@ -21,9 +21,9 @@ func GenerateConnectionID(len int) (ConnectionID, error) {
|
||||||
return ConnectionID(b), nil
|
return ConnectionID(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateDestinationConnectionID generates a connection ID for the Initial packet.
|
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
|
||||||
// It uses a length randomly chosen between 8 and 18 bytes.
|
// It uses a length randomly chosen between 8 and 18 bytes.
|
||||||
func GenerateDestinationConnectionID() (ConnectionID, error) {
|
func GenerateConnectionIDForInitial() (ConnectionID, error) {
|
||||||
r := make([]byte, 1)
|
r := make([]byte, 1)
|
||||||
if _, err := rand.Read(r); err != nil {
|
if _, err := rand.Read(r); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -27,7 +27,7 @@ var _ = Describe("Connection ID generation", func() {
|
||||||
It("generates random length destination connection IDs", func() {
|
It("generates random length destination connection IDs", func() {
|
||||||
var has8ByteConnID, has18ByteConnID bool
|
var has8ByteConnID, has18ByteConnID bool
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
c, err := GenerateDestinationConnectionID()
|
c, err := GenerateConnectionIDForInitial()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(c.Len()).To(BeNumerically(">=", 8))
|
Expect(c.Len()).To(BeNumerically(">=", 8))
|
||||||
Expect(c.Len()).To(BeNumerically("<=", 18))
|
Expect(c.Len()).To(BeNumerically("<=", 18))
|
||||||
|
|
|
@ -128,6 +128,9 @@ func (s *serverTLS) handleInitialImpl(p *receivedPacket) (quicSession, protocol.
|
||||||
mconf := s.mintConf.Clone()
|
mconf := s.mintConf.Clone()
|
||||||
mconf.ExtensionHandler = extHandler
|
mconf.ExtensionHandler = extHandler
|
||||||
|
|
||||||
|
// A server is allowed to perform multiple Retries.
|
||||||
|
// It doesn't make much sense, but it's something that our API allows.
|
||||||
|
// In that case it must use a source connection ID of at least 8 bytes.
|
||||||
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -159,7 +162,7 @@ func (s *serverTLS) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
|
connID, err := protocol.GenerateConnectionIDForInitial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,6 +95,7 @@ var _ = Describe("Stateless TLS handling", func() {
|
||||||
replyHdr := parseHeader(conn.dataWritten.Bytes())
|
replyHdr := parseHeader(conn.dataWritten.Bytes())
|
||||||
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
|
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
|
||||||
|
Expect(replyHdr.SrcConnectionID.Len()).To(BeNumerically(">=", protocol.MinConnectionIDLenInitial))
|
||||||
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
|
||||||
Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID))
|
Expect(replyHdr.OrigDestConnectionID).To(Equal(hdr.DestConnectionID))
|
||||||
Expect(replyHdr.Token).ToNot(BeEmpty())
|
Expect(replyHdr.Token).ToNot(BeEmpty())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue