pass the crypto stream to the crypto setup constructor

The crypto stream is opened during the session setup. Passing it to the
crypto setup directly helps simplify the constructor.
This commit is contained in:
Marten Seemann 2017-10-21 09:48:25 +07:00
parent a88da29433
commit 282b423f7d
10 changed files with 75 additions and 74 deletions

View file

@ -66,6 +66,7 @@ var (
// NewCryptoSetupClient creates a new CryptoSetup instance for a client // NewCryptoSetupClient creates a new CryptoSetup instance for a client
func NewCryptoSetupClient( func NewCryptoSetupClient(
cryptoStream io.ReadWriter,
hostname string, hostname string,
connID protocol.ConnectionID, connID protocol.ConnectionID,
version protocol.VersionNumber, version protocol.VersionNumber,
@ -76,6 +77,7 @@ func NewCryptoSetupClient(
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
return &cryptoSetupClient{ return &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: hostname, hostname: hostname,
connID: connID, connID: connID,
version: version, version: version,
@ -91,12 +93,10 @@ func NewCryptoSetupClient(
}, nil }, nil
} }
func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error { func (h *cryptoSetupClient) HandleCryptoStream() error {
messageChan := make(chan HandshakeMessage) messageChan := make(chan HandshakeMessage)
errorChan := make(chan error) errorChan := make(chan error)
h.cryptoStream = stream
go func() { go func() {
for { for {
message, err := ParseHandshakeMessage(h.cryptoStream) message, err := ParseHandshakeMessage(h.cryptoStream)

View file

@ -113,6 +113,7 @@ var _ = Describe("Client Crypto Setup", func() {
paramsChan = make(chan TransportParameters, 1) paramsChan = make(chan TransportParameters, 1)
aeadChanged = make(chan protocol.EncryptionLevel, 2) aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupClient( csInt, err := NewCryptoSetupClient(
stream,
"hostname", "hostname",
0, 0,
version, version,
@ -144,13 +145,13 @@ var _ = Describe("Client Crypto Setup", func() {
It("rejects handshake messages with the wrong message tag", func() { It("rejects handshake messages with the wrong message tag", func() {
HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
}) })
It("errors on invalid handshake messages", func() { It("errors on invalid handshake messages", func() {
stream.dataToRead.Write([]byte("invalid message")) stream.dataToRead.Write([]byte("invalid message"))
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed))
}) })
@ -159,7 +160,7 @@ var _ = Describe("Client Crypto Setup", func() {
stk := []byte("foobar") stk := []byte("foobar")
tagMap[TagSTK] = stk tagMap[TagSTK] = stk
HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead)
go cs.HandleCryptoStream(stream) go cs.HandleCryptoStream()
Eventually(func() []byte { return cs.stk }).Should(Equal(stk)) Eventually(func() []byte { return cs.stk }).Should(Equal(stk))
}) })
@ -445,7 +446,7 @@ var _ = Describe("Client Crypto Setup", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure))) Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure)))
@ -457,7 +458,7 @@ var _ = Describe("Client Crypto Setup", func() {
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}() }()
var params TransportParameters var params TransportParameters
@ -682,7 +683,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("tries to escalate before reading a handshake message", func() { It("tries to escalate before reading a handshake message", func() {
Expect(cs.secureAEAD).To(BeNil()) Expect(cs.secureAEAD).To(BeNil())
cs.serverVerified = true cs.serverVerified = true
go cs.HandleCryptoStream(stream) go cs.HandleCryptoStream()
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(Receive())
@ -692,7 +693,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) { It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cs.HandleCryptoStream(stream) cs.HandleCryptoStream()
Fail("HandleCryptoStream should not have returned") Fail("HandleCryptoStream should not have returned")
}() }()
cs.diversificationNonce = nil cs.diversificationNonce = nil
@ -855,14 +856,14 @@ var _ = Describe("Client Crypto Setup", func() {
Context("Diversification Nonces", func() { Context("Diversification Nonces", func() {
It("sets a diversification nonce", func() { It("sets a diversification nonce", func() {
go cs.HandleCryptoStream(stream) go cs.HandleCryptoStream()
nonce := []byte("foobar") nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
}) })
It("doesn't do anything when called multiple times with the same nonce", func(done Done) { It("doesn't do anything when called multiple times with the same nonce", func(done Done) {
go cs.HandleCryptoStream(stream) go cs.HandleCryptoStream()
nonce := []byte("foobar") nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
@ -873,7 +874,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("rejects a different diversification nonce", func() { It("rejects a different diversification nonce", func() {
var err error var err error
go func() { go func() {
err = cs.HandleCryptoStream(stream) err = cs.HandleCryptoStream()
}() }()
nonce1 := []byte("foobar") nonce1 := []byte("foobar")

View file

@ -67,6 +67,7 @@ var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP exp
// NewCryptoSetup creates a new CryptoSetup instance for a server // NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup( func NewCryptoSetup(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID, connID protocol.ConnectionID,
remoteAddr net.Addr, remoteAddr net.Addr,
version protocol.VersionNumber, version protocol.VersionNumber,
@ -78,6 +79,7 @@ func NewCryptoSetup(
aeadChanged chan<- protocol.EncryptionLevel, aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
return &cryptoSetupServer{ return &cryptoSetupServer{
cryptoStream: cryptoStream,
connID: connID, connID: connID,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
version: version, version: version,
@ -95,9 +97,7 @@ func NewCryptoSetup(
} }
// HandleCryptoStream reads and writes messages on the crypto stream // HandleCryptoStream reads and writes messages on the crypto stream
func (h *cryptoSetupServer) HandleCryptoStream(stream io.ReadWriter) error { func (h *cryptoSetupServer) HandleCryptoStream() error {
h.cryptoStream = stream
for { for {
var chloData bytes.Buffer var chloData bytes.Buffer
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))

View file

@ -202,6 +202,7 @@ var _ = Describe("Server Crypto Setup", func() {
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1] version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99} supportedVersions = []protocol.VersionNumber{version, 98, 99}
csInt, err := NewCryptoSetup( csInt, err := NewCryptoSetup(
stream,
protocol.ConnectionID(42), protocol.ConnectionID(42),
remoteAddr, remoteAddr,
version, version,
@ -275,7 +276,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagFHL2: []byte("foobar"), TagFHL2: []byte("foobar"),
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(ErrHOLExperiment)) Expect(err).To(MatchError(ErrHOLExperiment))
}) })
@ -286,7 +287,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagNSTP: []byte("foobar"), TagNSTP: []byte("foobar"),
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(ErrNSTPExperiment)) Expect(err).To(MatchError(ErrNSTPExperiment))
}) })
@ -369,7 +370,7 @@ var _ = Describe("Server Crypto Setup", func() {
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
@ -381,14 +382,14 @@ var _ = Describe("Server Crypto Setup", func() {
It("rejects client nonces that have the wrong length", func() { It("rejects client nonces that have the wrong length", func() {
fullCHLO[TagNONC] = []byte("too short client nonce") fullCHLO[TagNONC] = []byte("too short client nonce")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")))
}) })
It("rejects client nonces that have the wrong OBIT value", func() { It("rejects client nonces that have the wrong OBIT value", func() {
fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0 fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")))
}) })
@ -396,13 +397,13 @@ var _ = Describe("Server Crypto Setup", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
kex.sharedKeyError = testErr kex.sharedKeyError = testErr
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
It("handles 0-RTT handshake", func() { It("handles 0-RTT handshake", func() {
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
@ -459,14 +460,14 @@ var _ = Describe("Server Crypto Setup", func() {
TagSNI: []byte("quic.clemente.io"), TagSNI: []byte("quic.clemente.io"),
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")))
}) })
It("rejects CHLOs with a version tag that has the wrong length", func() { It("rejects CHLOs with a version tag that has the wrong length", func() {
fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")))
}) })
@ -479,7 +480,7 @@ var _ = Describe("Server Crypto Setup", func() {
binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(lowestSupportedVersion)) binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(lowestSupportedVersion))
fullCHLO[TagVER] = b fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected"))) Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")))
}) })
@ -492,35 +493,35 @@ var _ = Describe("Server Crypto Setup", func() {
binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion)) binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion))
fullCHLO[TagVER] = b fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("errors if the AEAD tag is missing", func() { It("errors if the AEAD tag is missing", func() {
delete(fullCHLO, TagAEAD) delete(fullCHLO, TagAEAD)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
It("errors if the AEAD tag has the wrong value", func() { It("errors if the AEAD tag has the wrong value", func() {
fullCHLO[TagAEAD] = []byte("wrong") fullCHLO[TagAEAD] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
It("errors if the KEXS tag is missing", func() { It("errors if the KEXS tag is missing", func() {
delete(fullCHLO, TagKEXS) delete(fullCHLO, TagKEXS)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
It("errors if the KEXS tag has the wrong value", func() { It("errors if the KEXS tag has the wrong value", func() {
fullCHLO[TagKEXS] = []byte("wrong") fullCHLO[TagKEXS] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
}) })
@ -532,7 +533,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagSTK: validSTK, TagSTK: validSTK,
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
}) })
@ -544,19 +545,19 @@ var _ = Describe("Server Crypto Setup", func() {
TagSNI: nil, TagSNI: nil,
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
}) })
It("errors with invalid message", func() { It("errors with invalid message", func() {
stream.dataToRead.Write([]byte("invalid message")) stream.dataToRead.Write([]byte("invalid message"))
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.HandshakeFailed)) Expect(err).To(MatchError(qerr.HandshakeFailed))
}) })
It("errors with non-CHLO message", func() { It("errors with non-CHLO message", func() {
HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream(stream) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
}) })

View file

@ -9,7 +9,6 @@ import (
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
) )
// KeyDerivationFunction is used for key derivation // KeyDerivationFunction is used for key derivation
@ -22,7 +21,7 @@ type cryptoSetupTLS struct {
keyDerivation KeyDerivationFunction keyDerivation KeyDerivationFunction
mintConf *mint.Config conn *mint.Conn
extensionHandler mint.AppExtensionHandler extensionHandler mint.AppExtensionHandler
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
@ -37,6 +36,7 @@ var newMintController = func(conn *mint.Conn) crypto.MintController {
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer( func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter,
tlsConfig *tls.Config, tlsConfig *tls.Config,
params *TransportParameters, params *TransportParameters,
paramsChan chan<- TransportParameters, paramsChan chan<- TransportParameters,
@ -51,7 +51,7 @@ func NewCryptoSetupTLSServer(
return &cryptoSetupTLS{ return &cryptoSetupTLS{
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
mintConf: mintConf, conn: mint.Server(&fakeConn{cryptoStream}, mintConf),
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
@ -61,6 +61,7 @@ func NewCryptoSetupTLSServer(
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
func NewCryptoSetupTLSClient( func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter,
hostname string, // only needed for the client hostname string, // only needed for the client
tlsConfig *tls.Config, tlsConfig *tls.Config,
params *TransportParameters, params *TransportParameters,
@ -78,7 +79,7 @@ func NewCryptoSetupTLSClient(
return &cryptoSetupTLS{ return &cryptoSetupTLS{
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
mintConf: mintConf, conn: mint.Client(&fakeConn{cryptoStream}, mintConf),
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
@ -86,18 +87,11 @@ func NewCryptoSetupTLSClient(
}, nil }, nil
} }
func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error { func (h *cryptoSetupTLS) HandleCryptoStream() error {
var conn *mint.Conn if err := h.conn.SetExtensionHandler(h.extensionHandler); err != nil {
if h.perspective == protocol.PerspectiveServer {
conn = mint.Server(&fakeConn{cryptoStream}, h.mintConf)
} else {
conn = mint.Client(&fakeConn{cryptoStream}, h.mintConf)
}
utils.Debugf("setting extension handler: %#v\n", h.extensionHandler)
if err := conn.SetExtensionHandler(h.extensionHandler); err != nil {
return err return err
} }
mc := newMintController(conn) mc := newMintController(h.conn)
if alert := mc.Handshake(); alert != mint.AlertNoAlert { if alert := mc.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)

View file

@ -43,6 +43,7 @@ var _ = Describe("TLS Crypto Setup", func() {
paramsChan = make(chan TransportParameters) paramsChan = make(chan TransportParameters)
aeadChanged = make(chan protocol.EncryptionLevel, 2) aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupTLSServer( csInt, err := NewCryptoSetupTLSServer(
nil,
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
&TransportParameters{}, &TransportParameters{},
paramsChan, paramsChan,
@ -63,7 +64,7 @@ var _ = Describe("TLS Crypto Setup", func() {
newMintController = func(*mint.Conn) crypto.MintController { newMintController = func(*mint.Conn) crypto.MintController {
return &fakeMintController{result: alert} return &fakeMintController{result: alert}
} }
err := cs.HandleCryptoStream(nil) err := cs.HandleCryptoStream()
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
}) })
@ -72,7 +73,7 @@ var _ = Describe("TLS Crypto Setup", func() {
return &fakeMintController{result: mint.AlertNoAlert} return &fakeMintController{result: mint.AlertNoAlert}
} }
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream(nil) err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).To(BeClosed()) Expect(aeadChanged).To(BeClosed())
@ -86,7 +87,7 @@ var _ = Describe("TLS Crypto Setup", func() {
return &fakeMintController{result: mint.AlertNoAlert} return &fakeMintController{result: mint.AlertNoAlert}
} }
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream(nil) err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }

View file

@ -1,8 +1,6 @@
package handshake package handshake
import ( import (
"io"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
@ -15,7 +13,7 @@ type Sealer interface {
// CryptoSetup is a crypto setup // CryptoSetup is a crypto setup
type CryptoSetup interface { type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream(io.ReadWriter) error HandleCryptoStream() error
// TODO: clean up this interface // TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient SetDiversificationNonce([]byte) // only needed for cryptoSetupClient

View file

@ -2,7 +2,6 @@ package quic
import ( import (
"bytes" "bytes"
"io"
"math" "math"
"github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/ackhandler"
@ -33,7 +32,7 @@ type mockCryptoSetup struct {
var _ handshake.CryptoSetup = &mockCryptoSetup{} var _ handshake.CryptoSetup = &mockCryptoSetup{}
func (m *mockCryptoSetup) HandleCryptoStream(io.ReadWriter) error { func (m *mockCryptoSetup) HandleCryptoStream() error {
return m.handleErr return m.handleErr
} }
func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {

View file

@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@ -195,6 +196,21 @@ func (s *session) setup(
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
var cryptoStream io.ReadWriter
// open the crypto stream
if s.perspective == protocol.PerspectiveServer {
cryptoStream, _ = s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
} else {
cryptoStream, _ = s.OpenStream()
}
var err error var err error
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool { verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool {
@ -202,6 +218,7 @@ func (s *session) setup(
} }
if s.version.UsesTLS() { if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer( s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
cryptoStream,
tlsConf, tlsConf,
transportParams, transportParams,
paramsChan, paramsChan,
@ -211,6 +228,7 @@ func (s *session) setup(
) )
} else { } else {
s.cryptoSetup, err = newCryptoSetup( s.cryptoSetup, err = newCryptoSetup(
cryptoStream,
s.connectionID, s.connectionID,
s.conn.RemoteAddr(), s.conn.RemoteAddr(),
s.version, s.version,
@ -226,6 +244,7 @@ func (s *session) setup(
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
if s.version.UsesTLS() { if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient( s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
cryptoStream,
hostname, hostname,
tlsConf, tlsConf,
transportParams, transportParams,
@ -237,6 +256,7 @@ func (s *session) setup(
) )
} else { } else {
s.cryptoSetup, err = newCryptoSetupClient( s.cryptoSetup, err = newCryptoSetupClient(
cryptoStream,
hostname, hostname,
s.connectionID, s.connectionID,
s.version, s.version,
@ -252,12 +272,6 @@ func (s *session) setup(
return nil, nil, err return nil, nil, err
} }
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
s.streamFramer = newStreamFramer(s.streamsMap, s.connFlowController) s.streamFramer = newStreamFramer(s.streamsMap, s.connFlowController)
s.packer = newPacketPacker(s.connectionID, s.packer = newPacketPacker(s.connectionID,
s.cryptoSetup, s.cryptoSetup,
@ -267,14 +281,6 @@ func (s *session) setup(
) )
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
// open the crypto stream
if s.perspective == protocol.PerspectiveServer {
_, _ = s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
} else {
_, _ = s.OpenStream()
}
return s, handshakeChan, nil return s, handshakeChan, nil
} }
@ -282,10 +288,8 @@ func (s *session) setup(
func (s *session) run() error { func (s *session) run() error {
defer s.ctxCancel() defer s.ctxCancel()
// Start the crypto stream handler
go func() { go func() {
cryptoStream, _ := s.GetOrOpenStream(1) if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
if err := s.cryptoSetup.HandleCryptoStream(cryptoStream); err != nil {
s.Close(err) s.Close(err)
} }
}() }()

View file

@ -157,6 +157,7 @@ var _ = Describe("Session", func() {
cryptoSetup = &mockCryptoSetup{} cryptoSetup = &mockCryptoSetup{}
newCryptoSetup = func( newCryptoSetup = func(
_ io.ReadWriter,
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ net.Addr, _ net.Addr,
_ protocol.VersionNumber, _ protocol.VersionNumber,
@ -206,6 +207,7 @@ var _ = Describe("Session", func() {
BeforeEach(func() { BeforeEach(func() {
newCryptoSetup = func( newCryptoSetup = func(
_ io.ReadWriter,
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ net.Addr, _ net.Addr,
_ protocol.VersionNumber, _ protocol.VersionNumber,
@ -1512,6 +1514,7 @@ var _ = Describe("Client Session", func() {
cryptoSetup = &mockCryptoSetup{} cryptoSetup = &mockCryptoSetup{}
newCryptoSetupClient = func( newCryptoSetupClient = func(
_ io.ReadWriter,
_ string, _ string,
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ protocol.VersionNumber, _ protocol.VersionNumber,