initialize the connection parameters manager in the crypto setup

This commit is contained in:
Marten Seemann 2017-09-25 14:21:32 +07:00
parent 565869723a
commit 95901cdee4
12 changed files with 234 additions and 174 deletions

View file

@ -66,9 +66,7 @@ var (
func NewConnectionParamatersManager( func NewConnectionParamatersManager(
pers protocol.Perspective, pers protocol.Perspective,
v protocol.VersionNumber, v protocol.VersionNumber,
maxReceiveStreamFlowControlWindow protocol.ByteCount, params *TransportParameters,
maxReceiveConnectionFlowControlWindow protocol.ByteCount,
idleTimeout time.Duration,
) ConnectionParametersManager { ) ConnectionParametersManager {
h := &connectionParametersManager{ h := &connectionParametersManager{
perspective: pers, perspective: pers,
@ -77,11 +75,11 @@ func NewConnectionParamatersManager(
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, maxReceiveStreamFlowControlWindow: params.MaxReceiveStreamFlowControlWindow,
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, maxReceiveConnectionFlowControlWindow: params.MaxReceiveConnectionFlowControlWindow,
} }
h.idleConnectionStateLifetime = idleTimeout h.idleConnectionStateLifetime = params.IdleTimeout
if h.perspective == protocol.PerspectiveServer { if h.perspective == protocol.PerspectiveServer {
h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent h.maxStreamsPerConnection = protocol.MaxStreamsPerConnection // this is the value negotiated based on what the client sent
h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective h.maxIncomingDynamicStreamsPerConnection = protocol.MaxStreamsPerConnection // "incoming" seen from the client's perspective

View file

@ -23,16 +23,20 @@ var _ = Describe("ConnectionsParameterManager", func() {
cpm = NewConnectionParamatersManager( cpm = NewConnectionParamatersManager(
protocol.PerspectiveServer, protocol.PerspectiveServer,
protocol.VersionWhatever, protocol.VersionWhatever,
maxReceiveStreamFlowControlWindowServer, &TransportParameters{
maxReceiveConnectionFlowControlWindowServer, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindowServer,
idleTimeout, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindowServer,
IdleTimeout: idleTimeout,
},
).(*connectionParametersManager) ).(*connectionParametersManager)
cpmClient = NewConnectionParamatersManager( cpmClient = NewConnectionParamatersManager(
protocol.PerspectiveClient, protocol.PerspectiveClient,
protocol.VersionWhatever, protocol.VersionWhatever,
maxReceiveStreamFlowControlWindowClient, &TransportParameters{
maxReceiveConnectionFlowControlWindowClient, MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindowClient,
idleTimeout, MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindowClient,
IdleTimeout: idleTimeout,
},
).(*connectionParametersManager) ).(*connectionParametersManager)
}) })

View file

@ -51,8 +51,8 @@ type cryptoSetupClient struct {
forwardSecureAEAD crypto.AEAD forwardSecureAEAD crypto.AEAD
aeadChanged chan<- protocol.EncryptionLevel aeadChanged chan<- protocol.EncryptionLevel
params *TransportParameters requestConnIDTruncation bool
connectionParameters ConnectionParametersManager connectionParameters ConnectionParametersManager
} }
var _ CryptoSetup = &cryptoSetupClient{} var _ CryptoSetup = &cryptoSetupClient{}
@ -68,34 +68,34 @@ func NewCryptoSetupClient(
hostname string, hostname string,
connID protocol.ConnectionID, connID protocol.ConnectionID,
version protocol.VersionNumber, version protocol.VersionNumber,
cryptoStream io.ReadWriter,
tlsConfig *tls.Config, tlsConfig *tls.Config,
connectionParameters ConnectionParametersManager,
aeadChanged chan<- protocol.EncryptionLevel,
params *TransportParameters, params *TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, ConnectionParametersManager, error) {
cpm := NewConnectionParamatersManager(protocol.PerspectiveClient, version, params)
return &cryptoSetupClient{ return &cryptoSetupClient{
hostname: hostname, hostname: hostname,
connID: connID, connID: connID,
version: version, version: version,
cryptoStream: cryptoStream, certManager: crypto.NewCertManager(tlsConfig),
certManager: crypto.NewCertManager(tlsConfig), connectionParameters: cpm,
connectionParameters: connectionParameters, requestConnIDTruncation: params.RequestConnectionIDTruncation,
keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX, keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version), nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
negotiatedVersions: negotiatedVersions, negotiatedVersions: negotiatedVersions,
divNonceChan: make(chan []byte), divNonceChan: make(chan []byte),
params: params, }, cpm, nil
}, nil
} }
func (h *cryptoSetupClient) HandleCryptoStream() error { func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error {
messageChan := make(chan HandshakeMessage) messageChan := make(chan HandshakeMessage)
errorChan := make(chan error) errorChan := make(chan error)
h.cryptoStream = stream
go func() { go func() {
for { for {
message, err := ParseHandshakeMessage(h.cryptoStream) message, err := ParseHandshakeMessage(h.cryptoStream)
@ -401,7 +401,6 @@ func (h *cryptoSetupClient) sendCHLO() error {
} }
h.lastSentCHLO = b.Bytes() h.lastSentCHLO = b.Bytes()
return nil return nil
} }
@ -422,7 +421,7 @@ func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version)) binary.LittleEndian.PutUint32(versionTag, protocol.VersionNumberToTag(h.version))
tags[TagVER] = versionTag tags[TagVER] = versionTag
if h.params.RequestConnectionIDTruncation { if h.requestConnIDTruncation {
tags[TagTCID] = []byte{0, 0, 0, 0} tags[TagTCID] = []byte{0, 0, 0, 0}
} }
if len(h.stk) > 0 { if len(h.stk) > 0 {

View file

@ -109,20 +109,17 @@ var _ = Describe("Client Crypto Setup", func() {
certManager = &mockCertManager{} certManager = &mockCertManager{}
version := protocol.Version37 version := protocol.Version37
aeadChanged = make(chan protocol.EncryptionLevel, 2) aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupClient( csInt, _, err := NewCryptoSetupClient(
"hostname", "hostname",
0, 0,
version, version,
stream,
nil, nil,
NewConnectionParamatersManager( &TransportParameters{
protocol.PerspectiveClient, MaxReceiveStreamFlowControlWindow: protocol.DefaultMaxReceiveStreamFlowControlWindowClient,
version, MaxReceiveConnectionFlowControlWindow: protocol.DefaultMaxReceiveConnectionFlowControlWindowClient,
protocol.DefaultMaxReceiveStreamFlowControlWindowClient, protocol.DefaultMaxReceiveConnectionFlowControlWindowClient, IdleTimeout: protocol.DefaultIdleTimeout,
protocol.DefaultIdleTimeout, },
),
aeadChanged, aeadChanged,
&TransportParameters{},
nil, nil,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -131,6 +128,7 @@ var _ = Describe("Client Crypto Setup", func() {
cs.keyDerivation = keyDerivation cs.keyDerivation = keyDerivation
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
cs.nullAEAD = &mockAEAD{encLevel: protocol.EncryptionUnencrypted} cs.nullAEAD = &mockAEAD{encLevel: protocol.EncryptionUnencrypted}
cs.cryptoStream = stream
}) })
AfterEach(func() { AfterEach(func() {
@ -146,13 +144,13 @@ var _ = Describe("Client Crypto Setup", func() {
It("rejects handshake messages with the wrong message tag", func() { It("rejects handshake messages with the wrong message tag", func() {
HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
}) })
It("errors on invalid handshake messages", func() { It("errors on invalid handshake messages", func() {
stream.dataToRead.Write([]byte("invalid message")) stream.dataToRead.Write([]byte("invalid message"))
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed)) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed))
}) })
@ -161,7 +159,7 @@ var _ = Describe("Client Crypto Setup", func() {
stk := []byte("foobar") stk := []byte("foobar")
tagMap[TagSTK] = stk tagMap[TagSTK] = stk
HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead)
go cs.HandleCryptoStream() go cs.HandleCryptoStream(stream)
Eventually(func() []byte { return cs.stk }).Should(Equal(stk)) Eventually(func() []byte { return cs.stk }).Should(Equal(stk))
}) })
@ -494,7 +492,7 @@ var _ = Describe("Client Crypto Setup", func() {
}) })
It("requests to truncate the connection ID", func() { It("requests to truncate the connection ID", func() {
cs.params.RequestConnectionIDTruncation = true cs.requestConnIDTruncation = true
tags, err := cs.getTags() tags, err := cs.getTags()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(tags).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0})) Expect(tags).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
@ -663,7 +661,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("tries to escalate before reading a handshake message", func() { It("tries to escalate before reading a handshake message", func() {
Expect(cs.secureAEAD).To(BeNil()) Expect(cs.secureAEAD).To(BeNil())
cs.serverVerified = true cs.serverVerified = true
go cs.HandleCryptoStream() go cs.HandleCryptoStream(stream)
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure))) Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
Expect(cs.secureAEAD).ToNot(BeNil()) Expect(cs.secureAEAD).ToNot(BeNil())
Expect(aeadChanged).ToNot(Receive()) Expect(aeadChanged).ToNot(Receive())
@ -673,7 +671,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) { It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) {
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
cs.HandleCryptoStream() cs.HandleCryptoStream(stream)
Fail("HandleCryptoStream should not have returned") Fail("HandleCryptoStream should not have returned")
}() }()
cs.diversificationNonce = nil cs.diversificationNonce = nil
@ -836,14 +834,14 @@ var _ = Describe("Client Crypto Setup", func() {
Context("Diversification Nonces", func() { Context("Diversification Nonces", func() {
It("sets a diversification nonce", func() { It("sets a diversification nonce", func() {
go cs.HandleCryptoStream() go cs.HandleCryptoStream(stream)
nonce := []byte("foobar") nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce)) Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
}) })
It("doesn't do anything when called multiple times with the same nonce", func(done Done) { It("doesn't do anything when called multiple times with the same nonce", func(done Done) {
go cs.HandleCryptoStream() go cs.HandleCryptoStream(stream)
nonce := []byte("foobar") nonce := []byte("foobar")
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
cs.SetDiversificationNonce(nonce) cs.SetDiversificationNonce(nonce)
@ -854,7 +852,7 @@ var _ = Describe("Client Crypto Setup", func() {
It("rejects a different diversification nonce", func() { It("rejects a different diversification nonce", func() {
var err error var err error
go func() { go func() {
err = cs.HandleCryptoStream() err = cs.HandleCryptoStream(stream)
}() }()
nonce1 := []byte("foobar") nonce1 := []byte("foobar")

View file

@ -69,17 +69,17 @@ func NewCryptoSetup(
remoteAddr net.Addr, remoteAddr net.Addr,
version protocol.VersionNumber, version protocol.VersionNumber,
scfg *ServerConfig, scfg *ServerConfig,
cryptoStream io.ReadWriter, params *TransportParameters,
connectionParametersManager ConnectionParametersManager,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool, acceptSTK func(net.Addr, *Cookie) bool,
aeadChanged chan<- protocol.EncryptionLevel, aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) { ) (CryptoSetup, ConnectionParametersManager, error) {
stkGenerator, err := NewCookieGenerator() stkGenerator, err := NewCookieGenerator()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
cpm := NewConnectionParamatersManager(protocol.PerspectiveServer, version, params)
return &cryptoSetupServer{ return &cryptoSetupServer{
connID: connID, connID: connID,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
@ -90,16 +90,17 @@ func NewCryptoSetup(
keyDerivation: crypto.DeriveQuicCryptoAESKeys, keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX, keyExchange: getEphermalKEX,
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
cryptoStream: cryptoStream, connectionParameters: cpm,
connectionParameters: connectionParametersManager,
acceptSTKCallback: acceptSTK, acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}), sentSHLO: make(chan struct{}),
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
}, nil }, cpm, nil
} }
// HandleCryptoStream reads and writes messages on the crypto stream // HandleCryptoStream reads and writes messages on the crypto stream
func (h *cryptoSetupServer) HandleCryptoStream() error { func (h *cryptoSetupServer) HandleCryptoStream(stream io.ReadWriter) error {
h.cryptoStream = stream
for { for {
var chloData bytes.Buffer var chloData bytes.Buffer
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))

View file

@ -167,7 +167,6 @@ var _ = Describe("Server Crypto Setup", func() {
scfg *ServerConfig scfg *ServerConfig
cs *cryptoSetupServer cs *cryptoSetupServer
stream *mockStream stream *mockStream
cpm ConnectionParametersManager
aeadChanged chan protocol.EncryptionLevel aeadChanged chan protocol.EncryptionLevel
nonce32 []byte nonce32 []byte
versionTag []byte versionTag []byte
@ -198,19 +197,16 @@ var _ = Describe("Server Crypto Setup", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1] version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99} supportedVersions = []protocol.VersionNumber{version, 98, 99}
cpm = NewConnectionParamatersManager( csInt, _, err := NewCryptoSetup(
protocol.PerspectiveServer,
protocol.VersionWhatever,
protocol.DefaultMaxReceiveStreamFlowControlWindowServer, protocol.DefaultMaxReceiveConnectionFlowControlWindowServer,
protocol.DefaultIdleTimeout,
)
csInt, err := NewCryptoSetup(
protocol.ConnectionID(42), protocol.ConnectionID(42),
remoteAddr, remoteAddr,
version, version,
scfg, scfg,
stream, &TransportParameters{
cpm, MaxReceiveStreamFlowControlWindow: protocol.DefaultMaxReceiveStreamFlowControlWindowServer,
MaxReceiveConnectionFlowControlWindow: protocol.DefaultMaxReceiveConnectionFlowControlWindowServer,
IdleTimeout: protocol.DefaultIdleTimeout,
},
supportedVersions, supportedVersions,
nil, nil,
aeadChanged, aeadChanged,
@ -225,6 +221,7 @@ var _ = Describe("Server Crypto Setup", func() {
cs.keyDerivation = mockQuicCryptoKeyDerivation cs.keyDerivation = mockQuicCryptoKeyDerivation
cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} }
cs.nullAEAD = &mockAEAD{encLevel: protocol.EncryptionUnencrypted} cs.nullAEAD = &mockAEAD{encLevel: protocol.EncryptionUnencrypted}
cs.cryptoStream = stream
}) })
AfterEach(func() { AfterEach(func() {
@ -277,7 +274,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagFHL2: []byte("foobar"), TagFHL2: []byte("foobar"),
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(ErrHOLExperiment)) Expect(err).To(MatchError(ErrHOLExperiment))
}) })
@ -288,7 +285,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagNSTP: []byte("foobar"), TagNSTP: []byte("foobar"),
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(ErrNSTPExperiment)) Expect(err).To(MatchError(ErrNSTPExperiment))
}) })
@ -361,7 +358,7 @@ var _ = Describe("Server Crypto Setup", func() {
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure))) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
@ -373,14 +370,14 @@ var _ = Describe("Server Crypto Setup", func() {
It("rejects client nonces that have the wrong length", func() { It("rejects client nonces that have the wrong length", func() {
fullCHLO[TagNONC] = []byte("too short client nonce") fullCHLO[TagNONC] = []byte("too short client nonce")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")))
}) })
It("rejects client nonces that have the wrong OBIT value", func() { It("rejects client nonces that have the wrong OBIT value", func() {
fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0 fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")))
}) })
@ -388,13 +385,13 @@ var _ = Describe("Server Crypto Setup", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
kex.sharedKeyError = testErr kex.sharedKeyError = testErr
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
It("handles 0-RTT handshake", func() { It("handles 0-RTT handshake", func() {
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
@ -451,14 +448,14 @@ var _ = Describe("Server Crypto Setup", func() {
TagSNI: []byte("quic.clemente.io"), TagSNI: []byte("quic.clemente.io"),
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")))
}) })
It("rejects CHLOs with a version tag that has the wrong length", func() { It("rejects CHLOs with a version tag that has the wrong length", func() {
fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag"))) Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")))
}) })
@ -471,7 +468,7 @@ var _ = Describe("Server Crypto Setup", func() {
binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(lowestSupportedVersion)) binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(lowestSupportedVersion))
fullCHLO[TagVER] = b fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected"))) Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")))
}) })
@ -484,35 +481,35 @@ var _ = Describe("Server Crypto Setup", func() {
binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion)) binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion))
fullCHLO[TagVER] = b fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("errors if the AEAD tag is missing", func() { It("errors if the AEAD tag is missing", func() {
delete(fullCHLO, TagAEAD) delete(fullCHLO, TagAEAD)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
It("errors if the AEAD tag has the wrong value", func() { It("errors if the AEAD tag has the wrong value", func() {
fullCHLO[TagAEAD] = []byte("wrong") fullCHLO[TagAEAD] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
It("errors if the KEXS tag is missing", func() { It("errors if the KEXS tag is missing", func() {
delete(fullCHLO, TagKEXS) delete(fullCHLO, TagKEXS)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
It("errors if the KEXS tag has the wrong value", func() { It("errors if the KEXS tag has the wrong value", func() {
fullCHLO[TagKEXS] = []byte("wrong") fullCHLO[TagKEXS] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
}) })
}) })
@ -524,7 +521,7 @@ var _ = Describe("Server Crypto Setup", func() {
TagSTK: validSTK, TagSTK: validSTK,
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
}) })
@ -536,19 +533,19 @@ var _ = Describe("Server Crypto Setup", func() {
TagSNI: nil, TagSNI: nil,
}, },
}.Write(&stream.dataToRead) }.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
}) })
It("errors with invalid message", func() { It("errors with invalid message", func() {
stream.dataToRead.Write([]byte("invalid message")) stream.dataToRead.Write([]byte("invalid message"))
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.HandshakeFailed)) Expect(err).To(MatchError(qerr.HandshakeFailed))
}) })
It("errors with non-CHLO message", func() { It("errors with non-CHLO message", func() {
HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead) HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(stream)
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
}) })

View file

@ -22,7 +22,6 @@ type cryptoSetupTLS struct {
keyDerivation KeyDerivationFunction keyDerivation KeyDerivationFunction
mintConf *mint.Config mintConf *mint.Config
conn crypto.MintController
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
aead crypto.AEAD aead crypto.AEAD
@ -30,43 +29,47 @@ type cryptoSetupTLS struct {
aeadChanged chan<- protocol.EncryptionLevel aeadChanged chan<- protocol.EncryptionLevel
} }
var newMintController = func(conn *mint.Conn) crypto.MintController {
return &mintController{conn}
}
// NewCryptoSetupTLS creates a new CryptoSetup instance for a server // NewCryptoSetupTLS creates a new CryptoSetup instance for a server
func NewCryptoSetupTLS( func NewCryptoSetupTLS(
hostname string, // only needed for the client hostname string, // only needed for the client
perspective protocol.Perspective, perspective protocol.Perspective,
version protocol.VersionNumber, version protocol.VersionNumber,
tlsConfig *tls.Config, tlsConfig *tls.Config,
cryptoStream io.ReadWriter,
aeadChanged chan<- protocol.EncryptionLevel, aeadChanged chan<- protocol.EncryptionLevel,
) (CryptoSetup, error) { ) (CryptoSetup, ConnectionParametersManager, error) {
mintConf, err := tlsToMintConfig(tlsConfig, perspective) mintConf, err := tlsToMintConfig(tlsConfig, perspective)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
mintConf.ServerName = hostname mintConf.ServerName = hostname
var conn *mint.Conn
if perspective == protocol.PerspectiveServer {
conn = mint.Server(&fakeConn{cryptoStream}, mintConf)
} else {
conn = mint.Client(&fakeConn{cryptoStream}, mintConf)
}
return &cryptoSetupTLS{ return &cryptoSetupTLS{
perspective: perspective, perspective: perspective,
mintConf: mintConf, mintConf: mintConf,
conn: &mintController{conn},
nullAEAD: crypto.NewNullAEAD(perspective, version), nullAEAD: crypto.NewNullAEAD(perspective, version),
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
}, nil }, NewConnectionParamatersManager(perspective, version, &TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout}), nil
} }
func (h *cryptoSetupTLS) HandleCryptoStream() error { func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error {
alert := h.conn.Handshake() var conn *mint.Conn
if alert != mint.AlertNoAlert { if h.perspective == protocol.PerspectiveServer {
conn = mint.Server(&fakeConn{cryptoStream}, h.mintConf)
} else {
conn = mint.Client(&fakeConn{cryptoStream}, h.mintConf)
}
mc := newMintController(conn)
if alert := mc.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
} }
aead, err := h.keyDerivation(h.conn, h.perspective) aead, err := h.keyDerivation(mc, h.perspective)
if err != nil { if err != nil {
return err return err
} }

View file

@ -34,33 +34,42 @@ var _ = Describe("TLS Crypto Setup", func() {
var ( var (
cs *cryptoSetupTLS cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel aeadChanged chan protocol.EncryptionLevel
mintControllerConstructor = newMintController
) )
BeforeEach(func() { BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2) aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupTLS( csInt, _, err := NewCryptoSetupTLS(
"", "",
protocol.PerspectiveServer, protocol.PerspectiveServer,
protocol.VersionTLS, protocol.VersionTLS,
testdata.GetTLSConfig(), testdata.GetTLSConfig(),
nil,
aeadChanged, aeadChanged,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupTLS) cs = csInt.(*cryptoSetupTLS)
}) })
AfterEach(func() {
newMintController = mintControllerConstructor
})
It("errors when the handshake fails", func() { It("errors when the handshake fails", func() {
alert := mint.AlertBadRecordMAC alert := mint.AlertBadRecordMAC
cs.conn = &fakeMintController{result: alert} newMintController = func(*mint.Conn) crypto.MintController {
err := cs.HandleCryptoStream() return &fakeMintController{result: alert}
}
err := cs.HandleCryptoStream(nil)
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
}) })
It("derives keys", func() { It("derives keys", func() {
cs.conn = &fakeMintController{result: mint.AlertNoAlert} newMintController = func(*mint.Conn) crypto.MintController {
return &fakeMintController{result: mint.AlertNoAlert}
}
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure))) Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
Expect(aeadChanged).To(BeClosed()) Expect(aeadChanged).To(BeClosed())
@ -70,9 +79,11 @@ var _ = Describe("TLS Crypto Setup", func() {
var foobarFNVSigned []byte // a "foobar", FNV signed var foobarFNVSigned []byte // a "foobar", FNV signed
doHandshake := func() { doHandshake := func() {
cs.conn = &fakeMintController{result: mint.AlertNoAlert} newMintController = func(*mint.Conn) crypto.MintController {
return &fakeMintController{result: mint.AlertNoAlert}
}
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream(nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }

View file

@ -1,6 +1,11 @@
package handshake package handshake
import "github.com/lucas-clemente/quic-go/internal/protocol" import (
"io"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Sealer seals a packet // Sealer seals a packet
type Sealer interface { type Sealer interface {
@ -11,7 +16,7 @@ type Sealer interface {
// CryptoSetup is a crypto setup // CryptoSetup is a crypto setup
type CryptoSetup interface { type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error HandleCryptoStream(io.ReadWriter) error
// TODO: clean up this interface // TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
@ -23,5 +28,8 @@ type CryptoSetup interface {
// TransportParameters are parameters sent to the peer during the handshake // TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct { type TransportParameters struct {
RequestConnectionIDTruncation bool RequestConnectionIDTruncation bool
MaxReceiveStreamFlowControlWindow protocol.ByteCount
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
IdleTimeout time.Duration
} }

View file

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"io"
"math" "math"
"github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/ackhandler"
@ -32,7 +33,7 @@ type mockCryptoSetup struct {
var _ handshake.CryptoSetup = &mockCryptoSetup{} var _ handshake.CryptoSetup = &mockCryptoSetup{}
func (m *mockCryptoSetup) HandleCryptoStream() error { func (m *mockCryptoSetup) HandleCryptoStream(io.ReadWriter) error {
return m.handleErr return m.handleErr
} }
func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {

View file

@ -180,69 +180,57 @@ func (s *session) setup(
s.sessionCreationTime = now s.sessionCreationTime = now
s.rttStats = &congestion.RTTStats{} s.rttStats = &congestion.RTTStats{}
s.connectionParameters = handshake.NewConnectionParamatersManager( transportParams := &handshake.TransportParameters{
s.perspective, MaxReceiveStreamFlowControlWindow: protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow),
s.version, MaxReceiveConnectionFlowControlWindow: protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), IdleTimeout: s.config.IdleTimeout,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), }
s.config.IdleTimeout,
)
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
var err error var err error
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
cryptoStream, _ := s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool { verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool {
return s.config.AcceptCookie(clientAddr, cookie) return s.config.AcceptCookie(clientAddr, cookie)
} }
if s.version.UsesTLS() { if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLS( s.cryptoSetup, s.connectionParameters, err = handshake.NewCryptoSetupTLS(
"", "",
s.perspective, s.perspective,
s.version, s.version,
tlsConf, tlsConf,
cryptoStream,
aeadChanged, aeadChanged,
) )
} else { } else {
s.cryptoSetup, err = newCryptoSetup( s.cryptoSetup, s.connectionParameters, err = newCryptoSetup(
s.connectionID, s.connectionID,
s.conn.RemoteAddr(), s.conn.RemoteAddr(),
s.version, s.version,
scfg, scfg,
cryptoStream, transportParams,
s.connectionParameters,
s.config.Versions, s.config.Versions,
verifySourceAddr, verifySourceAddr,
aeadChanged, aeadChanged,
) )
} }
} else { } else {
cryptoStream, _ := s.OpenStream()
if s.version.UsesTLS() { if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLS( s.cryptoSetup, s.connectionParameters, err = handshake.NewCryptoSetupTLS(
hostname, hostname,
s.perspective, s.perspective,
s.version, s.version,
tlsConf, tlsConf,
cryptoStream,
aeadChanged, aeadChanged,
) )
} else { } else {
s.cryptoSetup, err = newCryptoSetupClient( transportParams.RequestConnectionIDTruncation = s.config.RequestConnectionIDTruncation
s.cryptoSetup, s.connectionParameters, err = newCryptoSetupClient(
hostname, hostname,
s.connectionID, s.connectionID,
s.version, s.version,
cryptoStream,
tlsConf, tlsConf,
s.connectionParameters, transportParams,
aeadChanged, aeadChanged,
&handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation},
negotiatedVersions, negotiatedVersions,
) )
} }
@ -251,6 +239,9 @@ func (s *session) setup(
return nil, nil, err return nil, nil, err
} }
s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
s.packer = newPacketPacker(s.connectionID, s.packer = newPacketPacker(s.connectionID,
s.cryptoSetup, s.cryptoSetup,
s.connectionParameters, s.connectionParameters,
@ -266,8 +257,16 @@ func (s *session) setup(
// run the session main loop // run the session main loop
func (s *session) run() error { func (s *session) run() error {
// Start the crypto stream handler // Start the crypto stream handler
var cryptoStream Stream
if s.perspective == protocol.PerspectiveServer {
cryptoStream, _ = s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
} else {
cryptoStream, _ = s.OpenStream()
}
go func() { go func() {
if err := s.cryptoSetup.HandleCryptoStream(); err != nil { if err := s.cryptoSetup.HandleCryptoStream(cryptoStream); err != nil {
s.Close(err) s.Close(err)
} }
}() }()

View file

@ -143,12 +143,48 @@ func areSessionsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*session).run") return strings.Contains(b.String(), "quic-go.(*session).run")
} }
type mockConnectionParametersManager struct {
}
func (m *mockConnectionParametersManager) SetFromMap(map[handshake.Tag][]byte) error {
panic("not implement")
}
func (m *mockConnectionParametersManager) GetHelloMap() (map[handshake.Tag][]byte, error) {
panic("not implement")
}
func (m *mockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
return protocol.InitialStreamFlowControlWindow
}
func (m *mockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
return protocol.InitialConnectionFlowControlWindow
}
func (m *mockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
return protocol.ReceiveStreamFlowControlWindow
}
func (m *mockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
return protocol.DefaultMaxReceiveStreamFlowControlWindowServer
}
func (m *mockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
return protocol.ReceiveConnectionFlowControlWindow
}
func (m *mockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
return protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
}
func (m *mockConnectionParametersManager) GetMaxOutgoingStreams() uint32 { return 100 }
func (m *mockConnectionParametersManager) GetMaxIncomingStreams() uint32 { return 100 }
func (m *mockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
return time.Hour
}
func (m *mockConnectionParametersManager) TruncateConnectionID() bool { return false }
var _ handshake.ConnectionParametersManager = &mockConnectionParametersManager{}
var _ = Describe("Session", func() { var _ = Describe("Session", func() {
var ( var (
sess *session sess *session
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
mconn *mockConnection mconn *mockConnection
mockCpm *mocks.MockConnectionParametersManager
cryptoSetup *mockCryptoSetup cryptoSetup *mockCryptoSetup
handshakeChan <-chan handshakeEvent handshakeChan <-chan handshakeEvent
aeadChanged chan<- protocol.EncryptionLevel aeadChanged chan<- protocol.EncryptionLevel
@ -163,14 +199,13 @@ var _ = Describe("Session", func() {
_ net.Addr, _ net.Addr,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ *handshake.ServerConfig, _ *handshake.ServerConfig,
_ io.ReadWriter, _ *handshake.TransportParameters,
_ handshake.ConnectionParametersManager,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
_ func(net.Addr, *Cookie) bool, _ func(net.Addr, *Cookie) bool,
aeadChangedP chan<- protocol.EncryptionLevel, aeadChangedP chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) {
aeadChanged = aeadChangedP aeadChanged = aeadChangedP
return cryptoSetup, nil return cryptoSetup, &mockConnectionParametersManager{}, nil
} }
mconn = newMockConnection() mconn = newMockConnection()
@ -190,11 +225,9 @@ var _ = Describe("Session", func() {
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session) sess = pSess.(*session)
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream Expect(sess.streamsMap.openStreams).To(BeEmpty()) // the crypto stream is opened in session.run()
mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) sess.connectionParameters = &mockConnectionParametersManager{}
mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(time.Minute).AnyTimes()
sess.connectionParameters = mockCpm
}) })
AfterEach(func() { AfterEach(func() {
@ -216,14 +249,13 @@ var _ = Describe("Session", func() {
_ net.Addr, _ net.Addr,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ *handshake.ServerConfig, _ *handshake.ServerConfig,
_ io.ReadWriter, _ *handshake.TransportParameters,
_ handshake.ConnectionParametersManager,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
cookieFunc func(net.Addr, *Cookie) bool, cookieFunc func(net.Addr, *Cookie) bool,
_ chan<- protocol.EncryptionLevel, _ chan<- protocol.EncryptionLevel,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) {
cookieVerify = cookieFunc cookieVerify = cookieFunc
return cryptoSetup, nil return cryptoSetup, &mockConnectionParametersManager{}, nil
} }
conf := populateServerConfig(&Config{}) conf := populateServerConfig(&Config{})
@ -730,18 +762,23 @@ var _ = Describe("Session", func() {
Context("accepting streams", func() { Context("accepting streams", func() {
It("waits for new streams", func() { It("waits for new streams", func() {
var str Stream strChan := make(chan Stream)
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
var err error for {
str, err = sess.AcceptStream() str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
strChan <- str
}
}() }()
Consistently(func() Stream { return str }).Should(BeNil()) Consistently(strChan).ShouldNot(Receive())
sess.handleStreamFrame(&wire.StreamFrame{ sess.handleStreamFrame(&wire.StreamFrame{
StreamID: 3, StreamID: 3,
}) })
Eventually(func() Stream { return str }).ShouldNot(BeNil()) var str Stream
Eventually(strChan).Should(Receive(&str))
Expect(str.StreamID()).To(Equal(protocol.StreamID(1)))
Eventually(strChan).Should(Receive(&str))
Expect(str.StreamID()).To(Equal(protocol.StreamID(3))) Expect(str.StreamID()).To(Equal(protocol.StreamID(3)))
}) })
@ -944,7 +981,7 @@ var _ = Describe("Session", func() {
}) })
Context("sending packets", func() { Context("sending packets", func() {
It("sends ack frames", func() { It("sends ACK frames", func() {
packetNumber := protocol.PacketNumber(0x035E) packetNumber := protocol.PacketNumber(0x035E)
sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
err := sess.sendPacket() err := sess.sendPacket()
@ -1528,7 +1565,7 @@ var _ = Describe("Session", func() {
It("does not use ICSL before handshake", func() { It("does not use ICSL before handshake", func() {
defer sess.Close(nil) defer sess.Close(nil)
sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) sess.lastNetworkActivityTime = time.Now().Add(-time.Minute)
mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl)
mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(9999 * time.Second).AnyTimes() mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(9999 * time.Second).AnyTimes()
mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes() mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes()
sess.connectionParameters = mockCpm sess.connectionParameters = mockCpm
@ -1545,7 +1582,7 @@ var _ = Describe("Session", func() {
It("uses ICSL after handshake", func(done Done) { It("uses ICSL after handshake", func(done Done) {
close(aeadChanged) close(aeadChanged)
mockCpm = mocks.NewMockConnectionParametersManager(mockCtrl) mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl)
mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second) mockCpm.EXPECT().GetIdleConnectionStateLifetime().Return(0 * time.Second)
mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes() mockCpm.EXPECT().TruncateConnectionID().Return(false).AnyTimes()
sess.connectionParameters = mockCpm sess.connectionParameters = mockCpm
@ -1599,7 +1636,9 @@ var _ = Describe("Session", func() {
Context("counting streams", func() { Context("counting streams", func() {
It("errors when too many streams are opened", func() { It("errors when too many streams are opened", func() {
for i := 0; i < 110; i++ { mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl)
mockCpm.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes()
for i := 0; i < 10; i++ {
_, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
} }
@ -1608,6 +1647,8 @@ var _ = Describe("Session", func() {
}) })
It("does not error when many streams are opened and closed", func() { It("does not error when many streams are opened and closed", func() {
mockCpm := mocks.NewMockConnectionParametersManager(mockCtrl)
mockCpm.EXPECT().GetMaxIncomingStreams().Return(uint32(10)).AnyTimes()
for i := 2; i <= 1000; i++ { for i := 2; i <= 1000; i++ {
s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1))
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -1641,11 +1682,13 @@ var _ = Describe("Session", func() {
Context("window updates", func() { Context("window updates", func() {
It("gets stream level window updates", func() { It("gets stream level window updates", func() {
err := sess.flowControlManager.AddBytesRead(1, protocol.ReceiveStreamFlowControlWindow) _, err := sess.GetOrOpenStream(3)
Expect(err).ToNot(HaveOccurred())
err = sess.flowControlManager.AddBytesRead(3, protocol.ReceiveStreamFlowControlWindow)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
frames := sess.getWindowUpdateFrames() frames := sess.getWindowUpdateFrames()
Expect(frames).To(HaveLen(1)) Expect(frames).To(HaveLen(1))
Expect(frames[0].StreamID).To(Equal(protocol.StreamID(1))) Expect(frames[0].StreamID).To(Equal(protocol.StreamID(3)))
Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow * 2)) Expect(frames[0].ByteOffset).To(BeEquivalentTo(protocol.ReceiveStreamFlowControlWindow * 2))
}) })
@ -1691,15 +1734,13 @@ var _ = Describe("Client Session", func() {
_ string, _ string,
_ protocol.ConnectionID, _ protocol.ConnectionID,
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ io.ReadWriter,
_ *tls.Config, _ *tls.Config,
_ handshake.ConnectionParametersManager,
aeadChangedP chan<- protocol.EncryptionLevel,
_ *handshake.TransportParameters, _ *handshake.TransportParameters,
aeadChangedP chan<- protocol.EncryptionLevel,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
) (handshake.CryptoSetup, error) { ) (handshake.CryptoSetup, handshake.ConnectionParametersManager, error) {
aeadChanged = aeadChangedP aeadChanged = aeadChangedP
return cryptoSetup, nil return cryptoSetup, &mockConnectionParametersManager{}, nil
} }
mconn = newMockConnection() mconn = newMockConnection()
@ -1714,7 +1755,7 @@ var _ = Describe("Client Session", func() {
) )
sess = sessP.(*session) sess = sessP.(*session)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream Expect(sess.streamsMap.openStreams).To(BeEmpty()) // the crypto stream is opened in session.run()
}) })
AfterEach(func() { AfterEach(func() {