diff --git a/connection_test.go b/connection_test.go index 67b02443..44564f84 100644 --- a/connection_test.go +++ b/connection_test.go @@ -100,8 +100,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().capabilities().DoAndReturn(func() connCapabilities { return capabilities }).AnyTimes() mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) + tokenGenerator := handshake.NewTokenGenerator([32]byte{0xa, 0xb, 0xc}) tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index ea261fb6..d0043ead 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -2,24 +2,22 @@ package tokens import ( "encoding/binary" - "math/rand" "net" "time" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" ) func Fuzz(data []byte) int { - if len(data) < 8 { + if len(data) < 32 { return -1 } - seed := binary.BigEndian.Uint64(data[:8]) - data = data[8:] - tg, err := handshake.NewTokenGenerator(rand.New(rand.NewSource(int64(seed)))) - if err != nil { - panic(err) - } + var key quic.TokenGeneratorKey + copy(key[:], data[:32]) + data = data[32:] + tg := handshake.NewTokenGenerator(key) if len(data) < 1 { return -1 } diff --git a/interface.go b/interface.go index 746ad12e..f5ee28d8 100644 --- a/interface.go +++ b/interface.go @@ -212,6 +212,9 @@ type EarlyConnection interface { // StatelessResetKey is a key used to derive stateless reset tokens. type StatelessResetKey [32]byte +// TokenGeneratorKey is a key used to encrypt session resumption tokens. +type TokenGeneratorKey = handshake.TokenProtectorKey + // A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. // It is not able to handle QUIC Connection IDs longer than 20 bytes, // as they are allowed by RFC 8999. diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index e5e90bb3..2d91e6b2 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/asn1" "fmt" - "io" "net" "time" @@ -45,15 +44,9 @@ type TokenGenerator struct { tokenProtector tokenProtector } -// NewTokenGenerator initializes a new TookenGenerator -func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { - tokenProtector, err := newTokenProtector(rand) - if err != nil { - return nil, err - } - return &TokenGenerator{ - tokenProtector: tokenProtector, - }, nil +// NewTokenGenerator initializes a new TokenGenerator +func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator { + return &TokenGenerator{tokenProtector: newTokenProtector(key)} } // NewRetryToken generates a new token for a Retry for a given source address diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 870ba3ed..7abc53b4 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -16,9 +16,9 @@ var _ = Describe("Token Generator", func() { var tokenGen *TokenGenerator BeforeEach(func() { - var err error - tokenGen, err = NewTokenGenerator(rand.Reader) - Expect(err).ToNot(HaveOccurred()) + var key TokenProtectorKey + rand.Read(key[:]) + tokenGen = NewTokenGenerator(key) }) It("generates a token", func() { diff --git a/internal/handshake/token_protector.go b/internal/handshake/token_protector.go index 650f230b..6dcf7f77 100644 --- a/internal/handshake/token_protector.go +++ b/internal/handshake/token_protector.go @@ -3,6 +3,7 @@ package handshake import ( "crypto/aes" "crypto/cipher" + "crypto/rand" "crypto/sha256" "fmt" "io" @@ -10,6 +11,9 @@ import ( "golang.org/x/crypto/hkdf" ) +// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens. +type TokenProtectorKey [32]byte + // TokenProtector is used to create and verify a token type tokenProtector interface { // NewToken creates a new token @@ -18,40 +22,29 @@ type tokenProtector interface { DecodeToken([]byte) ([]byte, error) } -const ( - tokenSecretSize = 32 - tokenNonceSize = 32 -) +const tokenNonceSize = 32 // tokenProtector is used to create and verify a token type tokenProtectorImpl struct { - rand io.Reader - secret []byte + key TokenProtectorKey } // newTokenProtector creates a source for source address tokens -func newTokenProtector(rand io.Reader) (tokenProtector, error) { - secret := make([]byte, tokenSecretSize) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - return &tokenProtectorImpl{ - rand: rand, - secret: secret, - }, nil +func newTokenProtector(key TokenProtectorKey) tokenProtector { + return &tokenProtectorImpl{key: key} } // NewToken encodes data into a new token. func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, tokenNonceSize) - if _, err := s.rand.Read(nonce); err != nil { + var nonce [tokenNonceSize]byte + if _, err := rand.Read(nonce[:]); err != nil { return nil, err } - aead, aeadNonce, err := s.createAEAD(nonce) + aead, aeadNonce, err := s.createAEAD(nonce[:]) if err != nil { return nil, err } - return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil + return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil } // DecodeToken decodes a token. @@ -68,7 +61,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { } func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) + h := hkdf.New(sha256.New, s.key[:], nonce[:], []byte("quic-go token source")) key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 if _, err := io.ReadFull(h, key); err != nil { return nil, nil, err diff --git a/internal/handshake/token_protector_test.go b/internal/handshake/token_protector_test.go index 03cd5320..74eb1f0c 100644 --- a/internal/handshake/token_protector_test.go +++ b/internal/handshake/token_protector_test.go @@ -7,41 +7,17 @@ import ( . "github.com/onsi/gomega" ) -type zeroReader struct{} - -func (r *zeroReader) Read(b []byte) (int, error) { - for i := range b { - b[i] = 0 - } - return len(b), nil -} - var _ = Describe("Token Protector", func() { var tp tokenProtector BeforeEach(func() { + var key TokenProtectorKey + rand.Read(key[:]) var err error - tp, err = newTokenProtector(rand.Reader) + tp = newTokenProtector(key) Expect(err).ToNot(HaveOccurred()) }) - It("uses the random source", func() { - tp1, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - tp2, err := newTokenProtector(&zeroReader{}) - Expect(err).ToNot(HaveOccurred()) - t1, err := tp1.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - t2, err := tp2.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t1).To(Equal(t2)) - tp3, err := newTokenProtector(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - t3, err := tp3.NewToken([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(t3).ToNot(Equal(t1)) - }) - It("encodes and decodes tokens", func() { token, err := tp.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) @@ -51,11 +27,34 @@ var _ = Describe("Token Protector", func() { Expect(decoded).To(Equal([]byte("foobar"))) }) - It("fails deconding invalid tokens", func() { + It("uses the different keys", func() { + var key1, key2 TokenProtectorKey + rand.Read(key1[:]) + rand.Read(key2[:]) + tp1 := newTokenProtector(key1) + tp2 := newTokenProtector(key2) + t1, err := tp1.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + t2, err := tp2.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + + _, err = tp1.DecodeToken(t1) + Expect(err).ToNot(HaveOccurred()) + _, err = tp1.DecodeToken(t2) + Expect(err).To(HaveOccurred()) + + // now create another token protector, reusing key1 + tp3 := newTokenProtector(key1) + _, err = tp3.DecodeToken(t1) + Expect(err).ToNot(HaveOccurred()) + _, err = tp3.DecodeToken(t2) + Expect(err).To(HaveOccurred()) + }) + + It("doesn't decode invalid tokens", func() { token, err := tp.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - token = token[1:] // remove the first byte - _, err = tp.DecodeToken(token) + _, err = tp.DecodeToken(token[1:]) // the token is invalid without the first byte Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("message authentication failed")) }) diff --git a/server.go b/server.go index d41c6465..671ae971 100644 --- a/server.go +++ b/server.go @@ -2,7 +2,6 @@ package quic import ( "context" - "crypto/rand" "crypto/tls" "errors" "fmt" @@ -227,18 +226,15 @@ func newServer( config *Config, tracer logging.Tracer, onClose func(), + tokenGeneratorKey TokenGeneratorKey, disableVersionNegotiation bool, acceptEarly bool, -) (*baseServer, error) { - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - if err != nil { - return nil, err - } +) *baseServer { s := &baseServer{ conn: conn, tlsConf: tlsConf, config: config, - tokenGenerator: tokenGenerator, + tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), connIDGenerator: connIDGenerator, connHandler: connHandler, connQueue: make(chan quicConn), @@ -260,7 +256,7 @@ func newServer( go s.run() go s.runSendQueue() s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - return s, nil + return s } func (s *baseServer) run() { diff --git a/transport.go b/transport.go index c021be77..b5855329 100644 --- a/transport.go +++ b/transport.go @@ -57,6 +57,12 @@ type Transport struct { // See section 10.3 of RFC 9000 for details. StatelessResetKey *StatelessResetKey + // The TokenGeneratorKey is used to encrypt session resumption tokens. + // If no key is configured, a random key will be generated. + // If multiple servers are authoritative for the same domain, they should use the same key, + // see section 8.1.3 of RFC 9000 for details. + TokenGeneratorKey *TokenGeneratorKey + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // This can be useful if version information is exchanged out-of-band. // It has no effect for clients. @@ -136,7 +142,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo if err := t.init(false); err != nil { return nil, err } - s, err := newServer( + s := newServer( t.conn, t.handlerMap, t.connIDGenerator, @@ -144,12 +150,10 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo conf, t.Tracer, t.closeServer, + *t.TokenGeneratorKey, t.DisableVersionNegotiationPackets, allow0RTT, ) - if err != nil { - return nil, err - } t.server = s return s, nil } @@ -203,6 +207,14 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.closeQueue = make(chan closePacket, 4) t.statelessResetQueue = make(chan receivedPacket, 4) + if t.TokenGeneratorKey == nil { + var key TokenGeneratorKey + if _, err := rand.Read(key[:]); err != nil { + t.initErr = err + return + } + t.TokenGeneratorKey = &key + } if t.ConnectionIDGenerator != nil { t.connIDGenerator = t.ConnectionIDGenerator