mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
pass the crypto stream to the crypto setup constructor
The crypto stream is opened during the session setup. Passing it to the crypto setup directly helps simplify the constructor.
This commit is contained in:
parent
a88da29433
commit
282b423f7d
10 changed files with 75 additions and 74 deletions
|
@ -66,6 +66,7 @@ var (
|
|||
|
||||
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
||||
func NewCryptoSetupClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
hostname string,
|
||||
connID protocol.ConnectionID,
|
||||
version protocol.VersionNumber,
|
||||
|
@ -76,6 +77,7 @@ func NewCryptoSetupClient(
|
|||
negotiatedVersions []protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
return &cryptoSetupClient{
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: hostname,
|
||||
connID: connID,
|
||||
version: version,
|
||||
|
@ -91,12 +93,10 @@ func NewCryptoSetupClient(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) HandleCryptoStream(stream io.ReadWriter) error {
|
||||
func (h *cryptoSetupClient) HandleCryptoStream() error {
|
||||
messageChan := make(chan HandshakeMessage)
|
||||
errorChan := make(chan error)
|
||||
|
||||
h.cryptoStream = stream
|
||||
|
||||
go func() {
|
||||
for {
|
||||
message, err := ParseHandshakeMessage(h.cryptoStream)
|
||||
|
|
|
@ -113,6 +113,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
paramsChan = make(chan TransportParameters, 1)
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
csInt, err := NewCryptoSetupClient(
|
||||
stream,
|
||||
"hostname",
|
||||
0,
|
||||
version,
|
||||
|
@ -144,13 +145,13 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
|
||||
It("rejects handshake messages with the wrong message tag", func() {
|
||||
HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
|
||||
})
|
||||
|
||||
It("errors on invalid handshake messages", func() {
|
||||
stream.dataToRead.Write([]byte("invalid message"))
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed))
|
||||
})
|
||||
|
@ -159,7 +160,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
stk := []byte("foobar")
|
||||
tagMap[TagSTK] = stk
|
||||
HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead)
|
||||
go cs.HandleCryptoStream(stream)
|
||||
go cs.HandleCryptoStream()
|
||||
Eventually(func() []byte { return cs.stk }).Should(Equal(stk))
|
||||
})
|
||||
|
||||
|
@ -445,7 +446,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
|
@ -457,7 +458,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
HandshakeMessage{Tag: TagSHLO, Data: shloMap}.Write(&stream.dataToRead)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}()
|
||||
var params TransportParameters
|
||||
|
@ -682,7 +683,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
It("tries to escalate before reading a handshake message", func() {
|
||||
Expect(cs.secureAEAD).To(BeNil())
|
||||
cs.serverVerified = true
|
||||
go cs.HandleCryptoStream(stream)
|
||||
go cs.HandleCryptoStream()
|
||||
Eventually(aeadChanged).Should(Receive(Equal(protocol.EncryptionSecure)))
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
Expect(aeadChanged).ToNot(Receive())
|
||||
|
@ -692,7 +693,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
It("tries to escalate the crypto after receiving a diversification nonce", func(done Done) {
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
cs.HandleCryptoStream(stream)
|
||||
cs.HandleCryptoStream()
|
||||
Fail("HandleCryptoStream should not have returned")
|
||||
}()
|
||||
cs.diversificationNonce = nil
|
||||
|
@ -855,14 +856,14 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
|
||||
Context("Diversification Nonces", func() {
|
||||
It("sets a diversification nonce", func() {
|
||||
go cs.HandleCryptoStream(stream)
|
||||
go cs.HandleCryptoStream()
|
||||
nonce := []byte("foobar")
|
||||
cs.SetDiversificationNonce(nonce)
|
||||
Eventually(func() []byte { return cs.diversificationNonce }).Should(Equal(nonce))
|
||||
})
|
||||
|
||||
It("doesn't do anything when called multiple times with the same nonce", func(done Done) {
|
||||
go cs.HandleCryptoStream(stream)
|
||||
go cs.HandleCryptoStream()
|
||||
nonce := []byte("foobar")
|
||||
cs.SetDiversificationNonce(nonce)
|
||||
cs.SetDiversificationNonce(nonce)
|
||||
|
@ -873,7 +874,7 @@ var _ = Describe("Client Crypto Setup", func() {
|
|||
It("rejects a different diversification nonce", func() {
|
||||
var err error
|
||||
go func() {
|
||||
err = cs.HandleCryptoStream(stream)
|
||||
err = cs.HandleCryptoStream()
|
||||
}()
|
||||
|
||||
nonce1 := []byte("foobar")
|
||||
|
|
|
@ -67,6 +67,7 @@ var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP exp
|
|||
|
||||
// NewCryptoSetup creates a new CryptoSetup instance for a server
|
||||
func NewCryptoSetup(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
remoteAddr net.Addr,
|
||||
version protocol.VersionNumber,
|
||||
|
@ -78,6 +79,7 @@ func NewCryptoSetup(
|
|||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
) (CryptoSetup, error) {
|
||||
return &cryptoSetupServer{
|
||||
cryptoStream: cryptoStream,
|
||||
connID: connID,
|
||||
remoteAddr: remoteAddr,
|
||||
version: version,
|
||||
|
@ -95,9 +97,7 @@ func NewCryptoSetup(
|
|||
}
|
||||
|
||||
// HandleCryptoStream reads and writes messages on the crypto stream
|
||||
func (h *cryptoSetupServer) HandleCryptoStream(stream io.ReadWriter) error {
|
||||
h.cryptoStream = stream
|
||||
|
||||
func (h *cryptoSetupServer) HandleCryptoStream() error {
|
||||
for {
|
||||
var chloData bytes.Buffer
|
||||
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
|
||||
|
|
|
@ -202,6 +202,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
||||
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
||||
csInt, err := NewCryptoSetup(
|
||||
stream,
|
||||
protocol.ConnectionID(42),
|
||||
remoteAddr,
|
||||
version,
|
||||
|
@ -275,7 +276,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
TagFHL2: []byte("foobar"),
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(ErrHOLExperiment))
|
||||
})
|
||||
|
||||
|
@ -286,7 +287,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
TagNSTP: []byte("foobar"),
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(ErrNSTPExperiment))
|
||||
})
|
||||
|
||||
|
@ -369,7 +370,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionSecure)))
|
||||
|
@ -381,14 +382,14 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
It("rejects client nonces that have the wrong length", func() {
|
||||
fullCHLO[TagNONC] = []byte("too short client nonce")
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")))
|
||||
})
|
||||
|
||||
It("rejects client nonces that have the wrong OBIT value", func() {
|
||||
fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")))
|
||||
})
|
||||
|
||||
|
@ -396,13 +397,13 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
testErr := errors.New("test error")
|
||||
kex.sharedKeyError = testErr
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("handles 0-RTT handshake", func() {
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
|
||||
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
|
||||
|
@ -459,14 +460,14 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
TagSNI: []byte("quic.clemente.io"),
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")))
|
||||
})
|
||||
|
||||
It("rejects CHLOs with a version tag that has the wrong length", func() {
|
||||
fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")))
|
||||
})
|
||||
|
||||
|
@ -479,7 +480,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(lowestSupportedVersion))
|
||||
fullCHLO[TagVER] = b
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")))
|
||||
})
|
||||
|
||||
|
@ -492,35 +493,35 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion))
|
||||
fullCHLO[TagVER] = b
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors if the AEAD tag is missing", func() {
|
||||
delete(fullCHLO, TagAEAD)
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
|
||||
It("errors if the AEAD tag has the wrong value", func() {
|
||||
fullCHLO[TagAEAD] = []byte("wrong")
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
|
||||
It("errors if the KEXS tag is missing", func() {
|
||||
delete(fullCHLO, TagKEXS)
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
|
||||
It("errors if the KEXS tag has the wrong value", func() {
|
||||
fullCHLO[TagKEXS] = []byte("wrong")
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
})
|
||||
|
@ -532,7 +533,7 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
TagSTK: validSTK,
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
|
||||
})
|
||||
|
||||
|
@ -544,19 +545,19 @@ var _ = Describe("Server Crypto Setup", func() {
|
|||
TagSNI: nil,
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
|
||||
})
|
||||
|
||||
It("errors with invalid message", func() {
|
||||
stream.dataToRead.Write([]byte("invalid message"))
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.HandshakeFailed))
|
||||
})
|
||||
|
||||
It("errors with non-CHLO message", func() {
|
||||
HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead)
|
||||
err := cs.HandleCryptoStream(stream)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
|
||||
})
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// KeyDerivationFunction is used for key derivation
|
||||
|
@ -22,7 +21,7 @@ type cryptoSetupTLS struct {
|
|||
|
||||
keyDerivation KeyDerivationFunction
|
||||
|
||||
mintConf *mint.Config
|
||||
conn *mint.Conn
|
||||
extensionHandler mint.AppExtensionHandler
|
||||
|
||||
nullAEAD crypto.AEAD
|
||||
|
@ -37,6 +36,7 @@ var newMintController = func(conn *mint.Conn) crypto.MintController {
|
|||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
func NewCryptoSetupTLSServer(
|
||||
cryptoStream io.ReadWriter,
|
||||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
|
@ -51,7 +51,7 @@ func NewCryptoSetupTLSServer(
|
|||
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveServer,
|
||||
mintConf: mintConf,
|
||||
conn: mint.Server(&fakeConn{cryptoStream}, mintConf),
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version),
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
|
@ -61,6 +61,7 @@ func NewCryptoSetupTLSServer(
|
|||
|
||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||
func NewCryptoSetupTLSClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
hostname string, // only needed for the client
|
||||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
|
@ -78,7 +79,7 @@ func NewCryptoSetupTLSClient(
|
|||
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveClient,
|
||||
mintConf: mintConf,
|
||||
conn: mint.Client(&fakeConn{cryptoStream}, mintConf),
|
||||
nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveClient, version),
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
|
@ -86,18 +87,11 @@ func NewCryptoSetupTLSClient(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream(cryptoStream io.ReadWriter) error {
|
||||
var conn *mint.Conn
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
conn = mint.Server(&fakeConn{cryptoStream}, h.mintConf)
|
||||
} else {
|
||||
conn = mint.Client(&fakeConn{cryptoStream}, h.mintConf)
|
||||
}
|
||||
utils.Debugf("setting extension handler: %#v\n", h.extensionHandler)
|
||||
if err := conn.SetExtensionHandler(h.extensionHandler); err != nil {
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
if err := h.conn.SetExtensionHandler(h.extensionHandler); err != nil {
|
||||
return err
|
||||
}
|
||||
mc := newMintController(conn)
|
||||
mc := newMintController(h.conn)
|
||||
|
||||
if alert := mc.Handshake(); alert != mint.AlertNoAlert {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
|
|
|
@ -43,6 +43,7 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
paramsChan = make(chan TransportParameters)
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
csInt, err := NewCryptoSetupTLSServer(
|
||||
nil,
|
||||
testdata.GetTLSConfig(),
|
||||
&TransportParameters{},
|
||||
paramsChan,
|
||||
|
@ -63,7 +64,7 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
newMintController = func(*mint.Conn) crypto.MintController {
|
||||
return &fakeMintController{result: alert}
|
||||
}
|
||||
err := cs.HandleCryptoStream(nil)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
|
||||
})
|
||||
|
||||
|
@ -72,7 +73,7 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
return &fakeMintController{result: mint.AlertNoAlert}
|
||||
}
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream(nil)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(aeadChanged).To(Receive(Equal(protocol.EncryptionForwardSecure)))
|
||||
Expect(aeadChanged).To(BeClosed())
|
||||
|
@ -86,7 +87,7 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
return &fakeMintController{result: mint.AlertNoAlert}
|
||||
}
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream(nil)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
|
@ -15,7 +13,7 @@ type Sealer interface {
|
|||
// CryptoSetup is a crypto setup
|
||||
type CryptoSetup interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
HandleCryptoStream(io.ReadWriter) error
|
||||
HandleCryptoStream() error
|
||||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
|
|
|
@ -2,7 +2,6 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
|
@ -33,7 +32,7 @@ type mockCryptoSetup struct {
|
|||
|
||||
var _ handshake.CryptoSetup = &mockCryptoSetup{}
|
||||
|
||||
func (m *mockCryptoSetup) HandleCryptoStream(io.ReadWriter) error {
|
||||
func (m *mockCryptoSetup) HandleCryptoStream() error {
|
||||
return m.handleErr
|
||||
}
|
||||
func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
|
|
38
session.go
38
session.go
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -195,6 +196,21 @@ func (s *session) setup(
|
|||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
|
||||
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ReceiveConnectionFlowControlWindow,
|
||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
s.rttStats,
|
||||
)
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
|
||||
var cryptoStream io.ReadWriter
|
||||
// open the crypto stream
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
cryptoStream, _ = s.GetOrOpenStream(1)
|
||||
_, _ = s.AcceptStream() // don't expose the crypto stream
|
||||
} else {
|
||||
cryptoStream, _ = s.OpenStream()
|
||||
}
|
||||
|
||||
var err error
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool {
|
||||
|
@ -202,6 +218,7 @@ func (s *session) setup(
|
|||
}
|
||||
if s.version.UsesTLS() {
|
||||
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
|
||||
cryptoStream,
|
||||
tlsConf,
|
||||
transportParams,
|
||||
paramsChan,
|
||||
|
@ -211,6 +228,7 @@ func (s *session) setup(
|
|||
)
|
||||
} else {
|
||||
s.cryptoSetup, err = newCryptoSetup(
|
||||
cryptoStream,
|
||||
s.connectionID,
|
||||
s.conn.RemoteAddr(),
|
||||
s.version,
|
||||
|
@ -226,6 +244,7 @@ func (s *session) setup(
|
|||
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
|
||||
if s.version.UsesTLS() {
|
||||
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
|
||||
cryptoStream,
|
||||
hostname,
|
||||
tlsConf,
|
||||
transportParams,
|
||||
|
@ -237,6 +256,7 @@ func (s *session) setup(
|
|||
)
|
||||
} else {
|
||||
s.cryptoSetup, err = newCryptoSetupClient(
|
||||
cryptoStream,
|
||||
hostname,
|
||||
s.connectionID,
|
||||
s.version,
|
||||
|
@ -252,12 +272,6 @@ func (s *session) setup(
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ReceiveConnectionFlowControlWindow,
|
||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
s.rttStats,
|
||||
)
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective)
|
||||
s.streamFramer = newStreamFramer(s.streamsMap, s.connFlowController)
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
s.cryptoSetup,
|
||||
|
@ -267,14 +281,6 @@ func (s *session) setup(
|
|||
)
|
||||
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
|
||||
|
||||
// open the crypto stream
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
_, _ = s.GetOrOpenStream(1)
|
||||
_, _ = s.AcceptStream() // don't expose the crypto stream
|
||||
} else {
|
||||
_, _ = s.OpenStream()
|
||||
}
|
||||
|
||||
return s, handshakeChan, nil
|
||||
}
|
||||
|
||||
|
@ -282,10 +288,8 @@ func (s *session) setup(
|
|||
func (s *session) run() error {
|
||||
defer s.ctxCancel()
|
||||
|
||||
// Start the crypto stream handler
|
||||
go func() {
|
||||
cryptoStream, _ := s.GetOrOpenStream(1)
|
||||
if err := s.cryptoSetup.HandleCryptoStream(cryptoStream); err != nil {
|
||||
if err := s.cryptoSetup.HandleCryptoStream(); err != nil {
|
||||
s.Close(err)
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -157,6 +157,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
cryptoSetup = &mockCryptoSetup{}
|
||||
newCryptoSetup = func(
|
||||
_ io.ReadWriter,
|
||||
_ protocol.ConnectionID,
|
||||
_ net.Addr,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -206,6 +207,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
BeforeEach(func() {
|
||||
newCryptoSetup = func(
|
||||
_ io.ReadWriter,
|
||||
_ protocol.ConnectionID,
|
||||
_ net.Addr,
|
||||
_ protocol.VersionNumber,
|
||||
|
@ -1512,6 +1514,7 @@ var _ = Describe("Client Session", func() {
|
|||
|
||||
cryptoSetup = &mockCryptoSetup{}
|
||||
newCryptoSetupClient = func(
|
||||
_ io.ReadWriter,
|
||||
_ string,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.VersionNumber,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue