use gomock to mock crypto.AEAD in the crypto setup tests

This commit is contained in:
Marten Seemann 2017-10-25 20:14:29 +07:00
parent e81795e49b
commit ddea1d83ab
6 changed files with 222 additions and 132 deletions

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
@ -55,58 +56,8 @@ func (*mockSigner) GetLeafCert(sni string) ([]byte, error) {
return []byte("certuncompressed"), nil
}
type mockAEAD struct {
encLevel protocol.EncryptionLevel
sharedSecret []byte
}
var _ crypto.AEAD = &mockAEAD{}
func (m *mockAEAD) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
if cap(dst) < len(src)+12 {
dst = make([]byte, len(src)+12)
}
dst = dst[:len(src)+12]
copy(dst, src)
switch m.encLevel {
case protocol.EncryptionUnencrypted:
copy(dst[len(src):], []byte(" unencrypted"))
case protocol.EncryptionSecure:
copy(dst[len(src):], []byte(" normal sec"))
case protocol.EncryptionForwardSecure:
copy(dst[len(src):], []byte(" forward sec"))
default:
Fail("invalid encryption level")
}
return dst
}
func (m *mockAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
if m.encLevel == protocol.EncryptionUnencrypted && string(src) == "unencrypted" ||
m.encLevel == protocol.EncryptionForwardSecure && string(src) == "forward secure encrypted" ||
m.encLevel == protocol.EncryptionSecure && string(src) == "encrypted" {
return []byte("decrypted"), nil
}
return nil, errors.New("authentication failed")
}
func (m *mockAEAD) Overhead() int {
return 12
}
var expectedInitialNonceLen int
var expectedFSNonceLen int
func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
var encLevel protocol.EncryptionLevel
if forwardSecure {
encLevel = protocol.EncryptionForwardSecure
Expect(nonces).To(HaveLen(expectedFSNonceLen))
} else {
encLevel = protocol.EncryptionSecure
Expect(nonces).To(HaveLen(expectedInitialNonceLen))
}
return &mockAEAD{encLevel: encLevel, sharedSecret: sharedSecret}, nil
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
type mockStream struct {
@ -180,11 +131,15 @@ var _ = Describe("Server Crypto Setup", func() {
sourceAddrValid bool
)
const (
expectedInitialNonceLen = 32
expectedFSNonceLen = 64
)
BeforeEach(func() {
var err error
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
expectedInitialNonceLen = 32
expectedFSNonceLen = 64
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
@ -222,7 +177,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid }
cs.keyDerivation = mockQuicCryptoKeyDerivation
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
cs.nullAEAD = &mockAEAD{encLevel: protocol.EncryptionUnencrypted}
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
cs.cryptoStream = stream
})
@ -232,7 +187,7 @@ var _ = Describe("Server Crypto Setup", func() {
Context("diversification nonce", func() {
BeforeEach(func() {
cs.secureAEAD = &mockAEAD{}
cs.secureAEAD = mockcrypto.NewMockAEAD(mockCtrl)
cs.receivedForwardSecurePacket = false
Expect(cs.DiversificationNonce()).To(BeEmpty())
@ -336,6 +291,20 @@ var _ = Describe("Server Crypto Setup", func() {
})
It("generates SHLO messages", func() {
var checkedSecure, checkedForwardSecure bool
cs.keyDerivation = func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
if forwardSecure {
Expect(nonces).To(HaveLen(expectedFSNonceLen))
checkedForwardSecure = true
Expect(sharedSecret).To(Equal([]byte("shared ephermal")))
} else {
Expect(nonces).To(HaveLen(expectedInitialNonceLen))
Expect(sharedSecret).To(Equal([]byte("shared key")))
checkedSecure = true
}
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
response, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
TagPUBS: []byte("pubs-c"),
TagNONC: nonce32,
@ -351,12 +320,8 @@ var _ = Describe("Server Crypto Setup", func() {
utils.LittleEndian.WriteUint32(b, protocol.VersionNumberToTag(v))
Expect(response).To(ContainSubstring(string(b.Bytes())))
}
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(cs.secureAEAD.(*mockAEAD).encLevel).To(Equal(protocol.EncryptionSecure))
Expect(cs.secureAEAD.(*mockAEAD).sharedSecret).To(Equal([]byte("shared key")))
Expect(cs.forwardSecureAEAD).ToNot(BeNil())
Expect(cs.forwardSecureAEAD.(*mockAEAD).sharedSecret).To(Equal([]byte("shared ephermal")))
Expect(cs.forwardSecureAEAD.(*mockAEAD).encLevel).To(Equal(protocol.EncryptionForwardSecure))
Expect(checkedSecure).To(BeTrue())
Expect(checkedForwardSecure).To(BeTrue())
})
It("handles long handshake", func() {
@ -576,65 +541,79 @@ var _ = Describe("Server Crypto Setup", func() {
Context("null encryption", func() {
It("is used initially", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar signed"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar unencrypted")))
d := sealer.Seal(nil, []byte("foobar"), 10, []byte{})
Expect(d).To(Equal([]byte("foobar signed")))
})
It("is used for crypto stream", func() {
It("is used for the crypto stream", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar unencrypted")))
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
})
It("is accepted initially", func() {
d, enc, err := cs.Open(nil, []byte("unencrypted"), 0, []byte{})
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(5), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("unencrypted"), 5, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
})
It("errors if the has the wrong hash", func() {
_, enc, err := cs.Open(nil, []byte("not unencrypted"), 0, []byte{})
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("not unencrypted"), protocol.PacketNumber(5), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err := cs.Open(nil, []byte("not unencrypted"), 5, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is still accepted after CHLO", func() {
doCHLO()
// it tries forward secure and secure decryption first
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{})
Expect(cs.secureAEAD).ToNot(BeNil())
_, enc, err := cs.Open(nil, []byte("unencrypted"), 0, []byte{})
_, enc, err := cs.Open(nil, []byte("unencrypted"), 99, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
})
It("is not accepted after receiving secure packet", func() {
doCHLO()
Expect(cs.secureAEAD).ToNot(BeNil())
d, enc, err := cs.Open(nil, []byte("encrypted"), 0, []byte{})
// first receive a secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
Expect(enc).To(Equal(protocol.EncryptionSecure))
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
_, enc, err = cs.Open(nil, []byte("foobar unencrypted"), 0, []byte{})
// now receive an unencrypted packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err = cs.Open(nil, []byte("unencrypted"), 99, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is not used after CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
enc, sealer := cs.GetSealer()
Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted))
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).ToNot(Equal([]byte("foobar unencrypted")))
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
})
})
Context("initial encryption", func() {
It("is accepted after CHLO", func() {
doCHLO()
d, enc, err := cs.Open(nil, []byte("encrypted"), 0, []byte{})
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
Expect(enc).To(Equal(protocol.EncryptionSecure))
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
@ -642,34 +621,41 @@ var _ = Describe("Server Crypto Setup", func() {
It("is not accepted after receiving forward secure packet", func() {
doCHLO()
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{})
// receive a forward secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
Expect(err).ToNot(HaveOccurred())
_, enc, err := cs.Open(nil, []byte("encrypted"), 0, []byte{})
// receive a secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(12), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err := cs.Open(nil, []byte("encrypted"), 12, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is used for crypto stream", func() {
It("is used for the crypto stream", func() {
doCHLO()
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(1), []byte{}).Return([]byte("foobar crypto stream"))
enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionSecure))
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar normal sec")))
d := sealer.Seal(nil, []byte("foobar"), 1, []byte{})
Expect(d).To(Equal([]byte("foobar crypto stream")))
})
})
Context("forward secure encryption", func() {
It("is used after the CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar forward sec"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
d := sealer.Seal(nil, []byte("foobar"), 20, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})
It("regards the handshake as complete once it receives a forward encrypted packet", func() {
doCHLO()
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{})
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(BeClosed())
})
@ -677,18 +663,20 @@ var _ = Describe("Server Crypto Setup", func() {
Context("forcing encryption levels", func() {
It("forces null encryption", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(11), []byte{}).Return([]byte("foobar unencrypted"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
d := sealer.Seal(nil, []byte("foobar"), 11, []byte{})
Expect(d).To(Equal([]byte("foobar unencrypted")))
})
It("forces initial encryption", func() {
doCHLO()
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(12), []byte{}).Return([]byte("foobar secure"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
Expect(d).To(Equal([]byte("foobar normal sec")))
d := sealer.Seal(nil, []byte("foobar"), 12, []byte{})
Expect(d).To(Equal([]byte("foobar secure")))
})
It("errors if no AEAD for initial encryption is available", func() {
@ -699,9 +687,10 @@ var _ = Describe("Server Crypto Setup", func() {
It("forces forward-secure encryption", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(13), []byte{}).Return([]byte("foobar forward sec"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 0, []byte{})
d := sealer.Seal(nil, []byte("foobar"), 13, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})