inject a random source into the token protector

This commit is contained in:
Marten Seemann 2020-08-23 17:03:15 +07:00
parent 166d91ae0f
commit 556bf18dbf
6 changed files with 44 additions and 10 deletions

View file

@ -3,6 +3,7 @@ package handshake
import (
"encoding/asn1"
"fmt"
"io"
"net"
"time"
@ -39,8 +40,8 @@ type TokenGenerator struct {
}
// NewTokenGenerator initializes a new TookenGenerator
func NewTokenGenerator() (*TokenGenerator, error) {
tokenProtector, err := newTokenProtector()
func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) {
tokenProtector, err := newTokenProtector(rand)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package handshake
import (
"crypto/rand"
"encoding/asn1"
"net"
"time"
@ -16,7 +17,7 @@ var _ = Describe("Token Generator", func() {
BeforeEach(func() {
var err error
tokenGen, err = NewTokenGenerator()
tokenGen, err = NewTokenGenerator(rand.Reader)
Expect(err).ToNot(HaveOccurred())
})

View file

@ -3,7 +3,6 @@ package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
@ -26,22 +25,26 @@ const (
// tokenProtector is used to create and verify a token
type tokenProtectorImpl struct {
rand io.Reader
secret []byte
}
// newTokenProtector creates a source for source address tokens
func newTokenProtector() (tokenProtector, error) {
func newTokenProtector(rand io.Reader) (tokenProtector, error) {
secret := make([]byte, tokenSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &tokenProtectorImpl{secret: secret}, nil
return &tokenProtectorImpl{
rand: rand,
secret: secret,
}, nil
}
// NewToken encodes data into a new token.
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, tokenNonceSize)
if _, err := rand.Read(nonce); err != nil {
if _, err := s.rand.Read(nonce); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce)

View file

@ -1,19 +1,47 @@
package handshake
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type zeroReader struct{}
func (r *zeroReader) Read(b []byte) (int, error) {
for i := range b {
b[i] = 0
}
return len(b), nil
}
var _ = Describe("Token Protector", func() {
var tp tokenProtector
BeforeEach(func() {
var err error
tp, err = newTokenProtector()
tp, err = newTokenProtector(rand.Reader)
Expect(err).ToNot(HaveOccurred())
})
It("uses the random source", func() {
tp1, err := newTokenProtector(&zeroReader{})
Expect(err).ToNot(HaveOccurred())
tp2, err := newTokenProtector(&zeroReader{})
Expect(err).ToNot(HaveOccurred())
t1, err := tp1.NewToken([]byte("foo"))
Expect(err).ToNot(HaveOccurred())
t2, err := tp2.NewToken([]byte("foo"))
Expect(err).ToNot(HaveOccurred())
Expect(t1).To(Equal(t2))
tp3, err := newTokenProtector(rand.Reader)
Expect(err).ToNot(HaveOccurred())
t3, err := tp3.NewToken([]byte("foo"))
Expect(err).ToNot(HaveOccurred())
Expect(t3).ToNot(Equal(t1))
})
It("encodes and decodes tokens", func() {
token, err := tp.NewToken([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())

View file

@ -3,6 +3,7 @@ package quic
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
@ -185,7 +186,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
if err != nil {
return nil, err
}
tokenGenerator, err := handshake.NewTokenGenerator()
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil {
return nil, err
}

View file

@ -86,7 +86,7 @@ var _ = Describe("Session", func() {
mconn = NewMockSendConn(mockCtrl)
mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
tokenGenerator, err := handshake.NewTokenGenerator()
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
Expect(err).ToNot(HaveOccurred())
tracer = mocks.NewMockConnectionTracer(mockCtrl)
tracer.EXPECT().SentTransportParameters(gomock.Any())