mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
implement stateless handling of Initial packets for the TLS server
This commit is contained in:
parent
57c6f3ceb5
commit
25a6dc9654
36 changed files with 1617 additions and 724 deletions
156
client.go
156
client.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
@ -23,14 +24,18 @@ type client struct {
|
||||||
hostname string
|
hostname string
|
||||||
|
|
||||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||||
versionNegotiated bool // has version negotiation completed yet
|
versionNegotiated bool // has the server accepted our version
|
||||||
receivedVersionNegotiationPacket bool
|
receivedVersionNegotiationPacket bool
|
||||||
|
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||||
|
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
config *Config
|
config *Config
|
||||||
|
tls handshake.MintTLS // only used when using TLS
|
||||||
|
|
||||||
connectionID protocol.ConnectionID
|
connectionID protocol.ConnectionID
|
||||||
version protocol.VersionNumber
|
|
||||||
|
initialVersion protocol.VersionNumber
|
||||||
|
version protocol.VersionNumber
|
||||||
|
|
||||||
session packetHandler
|
session packetHandler
|
||||||
}
|
}
|
||||||
|
@ -91,7 +96,6 @@ func DialNonFWSecure(
|
||||||
if tlsConf != nil {
|
if tlsConf != nil {
|
||||||
hostname = tlsConf.ServerName
|
hostname = tlsConf.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostname == "" {
|
if hostname == "" {
|
||||||
hostname, _, err = net.SplitHostPort(host)
|
hostname, _, err = net.SplitHostPort(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -111,8 +115,9 @@ func DialNonFWSecure(
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||||
|
go c.listen()
|
||||||
|
|
||||||
if err := c.establishSecureConnection(); err != nil {
|
if err := c.dial(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return c.session.(NonFWSession), nil
|
return c.session.(NonFWSession), nil
|
||||||
|
@ -177,25 +182,79 @@ func populateClientConfig(config *Config) *Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
func (c *client) dial() error {
|
||||||
func (c *client) establishSecureConnection() error {
|
var err error
|
||||||
if err := c.createNewSession(c.version, nil); err != nil {
|
if c.version.UsesTLS() {
|
||||||
|
err = c.dialTLS()
|
||||||
|
} else {
|
||||||
|
err = c.dialGQUIC()
|
||||||
|
}
|
||||||
|
if err == errCloseSessionForNewVersion {
|
||||||
|
return c.dial()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialGQUIC() error {
|
||||||
|
if err := c.createNewGQUICSession(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go c.listen()
|
return c.establishSecureConnection()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) dialTLS() error {
|
||||||
|
csc := handshake.NewCryptoStreamConn(nil)
|
||||||
|
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mintConf.ServerName = c.hostname
|
||||||
|
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||||
|
params := &handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
MaxStreams: protocol.MaxIncomingStreams,
|
||||||
|
IdleTimeout: c.config.IdleTimeout,
|
||||||
|
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||||
|
}
|
||||||
|
eh := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
|
||||||
|
if err := c.tls.SetExtensionHandler(eh); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
if err != handshake.ErrCloseSessionForRetry {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
utils.Infof("Received a Retry packet. Recreating session.")
|
||||||
|
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.establishSecureConnection(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||||
|
// It returns:
|
||||||
|
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||||
|
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||||
|
// - any other error that might occur
|
||||||
|
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||||
|
func (c *client) establishSecureConnection() error {
|
||||||
var runErr error
|
var runErr error
|
||||||
errorChan := make(chan struct{})
|
errorChan := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
// session.run() returns as soon as the session is closed
|
runErr = c.session.run() // returns as soon as the session is closed
|
||||||
runErr = c.session.run()
|
|
||||||
if runErr == errCloseSessionForNewVersion {
|
|
||||||
// run the new session
|
|
||||||
runErr = c.session.run()
|
|
||||||
}
|
|
||||||
close(errorChan)
|
close(errorChan)
|
||||||
utils.Infof("Connection %x closed.", c.connectionID)
|
utils.Infof("Connection %x closed.", c.connectionID)
|
||||||
c.conn.Close()
|
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait until the server accepts the QUIC version (or an error occurs)
|
// wait until the server accepts the QUIC version (or an error occurs)
|
||||||
|
@ -219,7 +278,8 @@ func (c *client) establishSecureConnection() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen listens
|
// Listen listens on the underlying connection and passes packets on for handling.
|
||||||
|
// It returns when the connection is closed.
|
||||||
func (c *client) listen() {
|
func (c *client) listen() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -233,13 +293,15 @@ func (c *client) listen() {
|
||||||
n, addr, err = c.conn.Read(data)
|
n, addr, err = c.conn.Read(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||||
c.session.Close(err)
|
c.mutex.Lock()
|
||||||
|
if c.session != nil {
|
||||||
|
c.session.Close(err)
|
||||||
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
data = data[:n]
|
c.handlePacket(addr, data[:n])
|
||||||
|
|
||||||
c.handlePacket(addr, data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,15 +319,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// reject packets with the wrong connection ID
|
|
||||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
|
// reject packets with the wrong connection ID
|
||||||
|
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if hdr.ResetFlag {
|
if hdr.ResetFlag {
|
||||||
cr := c.conn.RemoteAddr()
|
cr := c.conn.RemoteAddr()
|
||||||
// check if the remote address and the connection ID match
|
// check if the remote address and the connection ID match
|
||||||
|
@ -305,6 +368,8 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
||||||
close(c.versionNegotiationChan)
|
close(c.versionNegotiationChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
|
||||||
|
|
||||||
c.session.handlePacket(&receivedPacket{
|
c.session.handlePacket(&receivedPacket{
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
header: hdr,
|
header: hdr,
|
||||||
|
@ -323,15 +388,15 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.receivedVersionNegotiationPacket = true
|
|
||||||
|
|
||||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||||
if !ok {
|
if !ok {
|
||||||
return qerr.InvalidVersion
|
return qerr.InvalidVersion
|
||||||
}
|
}
|
||||||
|
c.receivedVersionNegotiationPacket = true
|
||||||
|
c.negotiatedVersions = hdr.SupportedVersions
|
||||||
|
|
||||||
// switch to negotiated version
|
// switch to negotiated version
|
||||||
initialVersion := c.version
|
c.initialVersion = c.version
|
||||||
c.version = newVersion
|
c.version = newVersion
|
||||||
var err error
|
var err error
|
||||||
c.connectionID, err = utils.GenerateConnectionID()
|
c.connectionID, err = utils.GenerateConnectionID()
|
||||||
|
@ -339,17 +404,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||||
|
c.session.Close(errCloseSessionForNewVersion)
|
||||||
// create a new session and close the old one
|
return nil
|
||||||
// the new session must be created first to update client member variables
|
|
||||||
oldSession := c.session
|
|
||||||
defer oldSession.Close(errCloseSessionForNewVersion)
|
|
||||||
return c.createNewSession(initialVersion, hdr.SupportedVersions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error {
|
func (c *client) createNewGQUICSession() (err error) {
|
||||||
var err error
|
c.mutex.Lock()
|
||||||
utils.Debugf("createNewSession with initial version %s", initialVersion)
|
defer c.mutex.Unlock()
|
||||||
c.session, err = newClientSession(
|
c.session, err = newClientSession(
|
||||||
c.conn,
|
c.conn,
|
||||||
c.hostname,
|
c.hostname,
|
||||||
|
@ -357,8 +418,27 @@ func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotia
|
||||||
c.connectionID,
|
c.connectionID,
|
||||||
c.tlsConf,
|
c.tlsConf,
|
||||||
c.config,
|
c.config,
|
||||||
initialVersion,
|
c.initialVersion,
|
||||||
negotiatedVersions,
|
c.negotiatedVersions,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) createNewTLSSession(
|
||||||
|
paramsChan <-chan handshake.TransportParameters,
|
||||||
|
version protocol.VersionNumber,
|
||||||
|
) (err error) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
c.session, err = newTLSClientSession(
|
||||||
|
c.conn,
|
||||||
|
c.hostname,
|
||||||
|
c.version,
|
||||||
|
c.connectionID,
|
||||||
|
c.config,
|
||||||
|
c.tls,
|
||||||
|
paramsChan,
|
||||||
|
1,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
130
client_test.go
130
client_test.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
@ -45,10 +46,9 @@ var _ = Describe("Client", func() {
|
||||||
msess, _ := newMockSession(nil, 0, 0, nil, nil, nil)
|
msess, _ := newMockSession(nil, 0, 0, nil, nil, nil)
|
||||||
sess = msess.(*mockSession)
|
sess = msess.(*mockSession)
|
||||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||||
packetConn = &mockPacketConn{
|
packetConn = newMockPacketConn()
|
||||||
addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234},
|
packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
||||||
dataReadFrom: addr,
|
packetConn.dataReadFrom = addr
|
||||||
}
|
|
||||||
config = &Config{
|
config = &Config{
|
||||||
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
|
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ var _ = Describe("Client", func() {
|
||||||
_ protocol.VersionNumber,
|
_ protocol.VersionNumber,
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
) (packetHandler, error) {
|
) (packetHandler, error) {
|
||||||
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
|
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
origGenerateConnectionID = generateConnectionID
|
origGenerateConnectionID = generateConnectionID
|
||||||
|
@ -101,7 +101,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("dials non-forward-secure", func() {
|
It("dials non-forward-secure", func() {
|
||||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||||
dialed := make(chan struct{})
|
dialed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -151,7 +151,7 @@ var _ = Describe("Client", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("Dial only returns after the handshake is complete", func() {
|
It("Dial only returns after the handshake is complete", func() {
|
||||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||||
dialed := make(chan struct{})
|
dialed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -237,7 +237,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("returns an error that occurs while waiting for the connection to become secure", func() {
|
It("returns an error that occurs while waiting for the connection to become secure", func() {
|
||||||
testErr := errors.New("early handshake error")
|
testErr := errors.New("early handshake error")
|
||||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -251,7 +251,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||||
testErr := errors.New("late handshake error")
|
testErr := errors.New("late handshake error")
|
||||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -330,10 +330,6 @@ var _ = Describe("Client", func() {
|
||||||
newVersion := protocol.VersionNumber(77)
|
newVersion := protocol.VersionNumber(77)
|
||||||
Expect(newVersion).ToNot(Equal(cl.version))
|
Expect(newVersion).ToNot(Equal(cl.version))
|
||||||
Expect(config.Versions).To(ContainElement(newVersion))
|
Expect(config.Versions).To(ContainElement(newVersion))
|
||||||
packetConn.dataToRead = wire.ComposeGQUICVersionNegotiation(
|
|
||||||
cl.connectionID,
|
|
||||||
[]protocol.VersionNumber{newVersion},
|
|
||||||
)
|
|
||||||
sessionChan := make(chan *mockSession)
|
sessionChan := make(chan *mockSession)
|
||||||
handshakeChan := make(chan handshakeEvent)
|
handshakeChan := make(chan handshakeEvent)
|
||||||
newClientSession = func(
|
newClientSession = func(
|
||||||
|
@ -348,10 +344,7 @@ var _ = Describe("Client", func() {
|
||||||
) (packetHandler, error) {
|
) (packetHandler, error) {
|
||||||
initialVersion = initialVersionP
|
initialVersion = initialVersionP
|
||||||
negotiatedVersions = negotiatedVersionsP
|
negotiatedVersions = negotiatedVersionsP
|
||||||
// make the server accept the new version
|
|
||||||
if len(negotiatedVersionsP) > 0 {
|
|
||||||
packetConn.dataToRead = acceptClientVersionPacket(connectionID)
|
|
||||||
}
|
|
||||||
sess := &mockSession{
|
sess := &mockSession{
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
stopRunLoop: make(chan struct{}),
|
stopRunLoop: make(chan struct{}),
|
||||||
|
@ -364,18 +357,26 @@ var _ = Describe("Client", func() {
|
||||||
established := make(chan struct{})
|
established := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
err := cl.establishSecureConnection()
|
err := cl.dial()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(established)
|
close(established)
|
||||||
}()
|
}()
|
||||||
|
go cl.listen()
|
||||||
|
|
||||||
actualInitialVersion := cl.version
|
actualInitialVersion := cl.version
|
||||||
var firstSession, secondSession *mockSession
|
var firstSession, secondSession *mockSession
|
||||||
Eventually(sessionChan).Should(Receive(&firstSession))
|
Eventually(sessionChan).Should(Receive(&firstSession))
|
||||||
Eventually(sessionChan).Should(Receive(&secondSession))
|
packetConn.dataToRead <- wire.ComposeGQUICVersionNegotiation(
|
||||||
|
cl.connectionID,
|
||||||
|
[]protocol.VersionNumber{newVersion},
|
||||||
|
)
|
||||||
// it didn't pass the version negoation packet to the old session (since it has no payload)
|
// it didn't pass the version negoation packet to the old session (since it has no payload)
|
||||||
Expect(firstSession.packetCount).To(BeZero())
|
|
||||||
Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
|
Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
|
||||||
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
|
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
|
||||||
|
Expect(firstSession.packetCount).To(BeZero())
|
||||||
|
Eventually(sessionChan).Should(Receive(&secondSession))
|
||||||
|
// make the server accept the new version
|
||||||
|
packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID)
|
||||||
Consistently(func() bool { return secondSession.closed }).Should(BeFalse())
|
Consistently(func() bool { return secondSession.closed }).Should(BeFalse())
|
||||||
Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337))
|
Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337))
|
||||||
Expect(negotiatedVersions).To(ContainElement(newVersion))
|
Expect(negotiatedVersions).To(ContainElement(newVersion))
|
||||||
|
@ -398,20 +399,23 @@ var _ = Describe("Client", func() {
|
||||||
_ []protocol.VersionNumber,
|
_ []protocol.VersionNumber,
|
||||||
) (packetHandler, error) {
|
) (packetHandler, error) {
|
||||||
atomic.AddUint32(&sessionCounter, 1)
|
atomic.AddUint32(&sessionCounter, 1)
|
||||||
return sess, nil
|
return &mockSession{
|
||||||
|
connectionID: connectionID,
|
||||||
|
stopRunLoop: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
go cl.establishSecureConnection()
|
go cl.dial()
|
||||||
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
|
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
|
||||||
newVersion := protocol.VersionNumber(77)
|
newVersion := protocol.VersionNumber(77)
|
||||||
Expect(newVersion).ToNot(Equal(cl.version))
|
Expect(newVersion).ToNot(Equal(cl.version))
|
||||||
Expect(config.Versions).To(ContainElement(newVersion))
|
Expect(config.Versions).To(ContainElement(newVersion))
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
|
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
|
||||||
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2))
|
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
||||||
newVersion = protocol.VersionNumber(78)
|
newVersion = protocol.VersionNumber(78)
|
||||||
Expect(newVersion).ToNot(Equal(cl.version))
|
Expect(newVersion).ToNot(Equal(cl.version))
|
||||||
Expect(config.Versions).To(ContainElement(newVersion))
|
Expect(config.Versions).To(ContainElement(newVersion))
|
||||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
|
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
|
||||||
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2))
|
Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors if no matching version is found", func() {
|
It("errors if no matching version is found", func() {
|
||||||
|
@ -482,7 +486,7 @@ var _ = Describe("Client", func() {
|
||||||
Expect(sess.closed).To(BeFalse())
|
Expect(sess.closed).To(BeFalse())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates new sessions with the right parameters", func() {
|
It("creates new GQUIC sessions with the right parameters", func() {
|
||||||
closeErr := errors.New("peer doesn't reply")
|
closeErr := errors.New("peer doesn't reply")
|
||||||
c := make(chan struct{})
|
c := make(chan struct{})
|
||||||
var cconn connection
|
var cconn connection
|
||||||
|
@ -516,12 +520,84 @@ var _ = Describe("Client", func() {
|
||||||
Eventually(c).Should(BeClosed())
|
Eventually(c).Should(BeClosed())
|
||||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||||
Expect(hostname).To(Equal("quic.clemente.io"))
|
Expect(hostname).To(Equal("quic.clemente.io"))
|
||||||
Expect(version).To(Equal(cl.version))
|
Expect(version).To(Equal(config.Versions[0]))
|
||||||
Expect(conf.Versions).To(Equal(config.Versions))
|
Expect(conf.Versions).To(Equal(config.Versions))
|
||||||
sess.Close(closeErr)
|
sess.Close(closeErr)
|
||||||
Eventually(dialed).Should(BeClosed())
|
Eventually(dialed).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("creates new TLS sessions with the right parameters", func() {
|
||||||
|
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
||||||
|
c := make(chan struct{})
|
||||||
|
var cconn connection
|
||||||
|
var hostname string
|
||||||
|
var version protocol.VersionNumber
|
||||||
|
var conf *Config
|
||||||
|
newTLSClientSession = func(
|
||||||
|
connP connection,
|
||||||
|
hostnameP string,
|
||||||
|
versionP protocol.VersionNumber,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
configP *Config,
|
||||||
|
tls handshake.MintTLS,
|
||||||
|
paramsChan <-chan handshake.TransportParameters,
|
||||||
|
_ protocol.PacketNumber,
|
||||||
|
) (packetHandler, error) {
|
||||||
|
cconn = connP
|
||||||
|
hostname = hostnameP
|
||||||
|
version = versionP
|
||||||
|
conf = configP
|
||||||
|
close(c)
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
dialed := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||||
|
close(dialed)
|
||||||
|
}()
|
||||||
|
Eventually(c).Should(BeClosed())
|
||||||
|
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||||
|
Expect(hostname).To(Equal("quic.clemente.io"))
|
||||||
|
Expect(version).To(Equal(config.Versions[0]))
|
||||||
|
Expect(conf.Versions).To(Equal(config.Versions))
|
||||||
|
sess.Close(errors.New("peer doesn't reply"))
|
||||||
|
Eventually(dialed).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("creates a new session when the server performs a retry", func() {
|
||||||
|
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
||||||
|
sessionChan := make(chan *mockSession)
|
||||||
|
newTLSClientSession = func(
|
||||||
|
connP connection,
|
||||||
|
hostnameP string,
|
||||||
|
versionP protocol.VersionNumber,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
configP *Config,
|
||||||
|
tls handshake.MintTLS,
|
||||||
|
paramsChan <-chan handshake.TransportParameters,
|
||||||
|
_ protocol.PacketNumber,
|
||||||
|
) (packetHandler, error) {
|
||||||
|
sess := &mockSession{
|
||||||
|
stopRunLoop: make(chan struct{}),
|
||||||
|
}
|
||||||
|
sessionChan <- sess
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
dialed := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||||
|
close(dialed)
|
||||||
|
}()
|
||||||
|
var firstSession, secondSession *mockSession
|
||||||
|
Eventually(sessionChan).Should(Receive(&firstSession))
|
||||||
|
firstSession.Close(handshake.ErrCloseSessionForRetry)
|
||||||
|
Eventually(sessionChan).Should(Receive(&secondSession))
|
||||||
|
secondSession.Close(errors.New("stop test"))
|
||||||
|
Eventually(dialed).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
Context("handling packets", func() {
|
Context("handling packets", func() {
|
||||||
It("handles packets", func() {
|
It("handles packets", func() {
|
||||||
ph := wire.Header{
|
ph := wire.Header{
|
||||||
|
@ -532,7 +608,7 @@ var _ = Describe("Client", func() {
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
err := ph.Write(b, protocol.PerspectiveServer, cl.version)
|
err := ph.Write(b, protocol.PerspectiveServer, cl.version)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
packetConn.dataToRead = b.Bytes()
|
packetConn.dataToRead <- b.Bytes()
|
||||||
|
|
||||||
Expect(sess.packetCount).To(BeZero())
|
Expect(sess.packetCount).To(BeZero())
|
||||||
stoppedListening := make(chan struct{})
|
stoppedListening := make(chan struct{})
|
||||||
|
|
31
conn_test.go
31
conn_test.go
|
@ -2,7 +2,7 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ import (
|
||||||
|
|
||||||
type mockPacketConn struct {
|
type mockPacketConn struct {
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
dataToRead []byte
|
dataToRead chan []byte
|
||||||
dataReadFrom net.Addr
|
dataReadFrom net.Addr
|
||||||
readErr error
|
readErr error
|
||||||
dataWritten bytes.Buffer
|
dataWritten bytes.Buffer
|
||||||
|
@ -20,23 +20,34 @@ type mockPacketConn struct {
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newMockPacketConn() *mockPacketConn {
|
||||||
|
return &mockPacketConn{
|
||||||
|
dataToRead: make(chan []byte, 1000),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||||
if c.readErr != nil {
|
if c.readErr != nil {
|
||||||
return 0, nil, c.readErr
|
return 0, nil, c.readErr
|
||||||
}
|
}
|
||||||
if c.dataToRead == nil { // block if there's no data
|
data, ok := <-c.dataToRead
|
||||||
time.Sleep(time.Hour)
|
if !ok {
|
||||||
return 0, nil, io.EOF
|
return 0, nil, errors.New("connection closed")
|
||||||
}
|
}
|
||||||
n := copy(b, c.dataToRead)
|
n := copy(b, data)
|
||||||
c.dataToRead = nil
|
|
||||||
return n, c.dataReadFrom, nil
|
return n, c.dataReadFrom, nil
|
||||||
}
|
}
|
||||||
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
c.dataWrittenTo = addr
|
c.dataWrittenTo = addr
|
||||||
return c.dataWritten.Write(b)
|
return c.dataWritten.Write(b)
|
||||||
}
|
}
|
||||||
func (c *mockPacketConn) Close() error { c.closed = true; return nil }
|
func (c *mockPacketConn) Close() error {
|
||||||
|
if !c.closed {
|
||||||
|
close(c.dataToRead)
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
|
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
|
||||||
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
|
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
|
||||||
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
|
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
|
||||||
|
@ -53,7 +64,7 @@ var _ = Describe("Connection", func() {
|
||||||
IP: net.IPv4(192, 168, 100, 200),
|
IP: net.IPv4(192, 168, 100, 200),
|
||||||
Port: 1337,
|
Port: 1337,
|
||||||
}
|
}
|
||||||
packetConn = &mockPacketConn{}
|
packetConn = newMockPacketConn()
|
||||||
c = &conn{
|
c = &conn{
|
||||||
currentAddr: addr,
|
currentAddr: addr,
|
||||||
pconn: packetConn,
|
pconn: packetConn,
|
||||||
|
@ -68,7 +79,7 @@ var _ = Describe("Connection", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("reads", func() {
|
It("reads", func() {
|
||||||
packetConn.dataToRead = []byte("foo")
|
packetConn.dataToRead <- []byte("foo")
|
||||||
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
|
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
|
||||||
p := make([]byte, 10)
|
p := make([]byte, 10)
|
||||||
n, raddr, err := c.Read(p)
|
n, raddr, err := c.Read(p)
|
||||||
|
|
|
@ -7,33 +7,33 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cookieHandler struct {
|
type CookieHandler struct {
|
||||||
callback func(net.Addr, *Cookie) bool
|
callback func(net.Addr, *Cookie) bool
|
||||||
|
|
||||||
cookieGenerator *CookieGenerator
|
cookieGenerator *CookieGenerator
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ mint.CookieHandler = &cookieHandler{}
|
var _ mint.CookieHandler = &CookieHandler{}
|
||||||
|
|
||||||
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) {
|
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
|
||||||
cookieGenerator, err := NewCookieGenerator()
|
cookieGenerator, err := NewCookieGenerator()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &cookieHandler{
|
return &CookieHandler{
|
||||||
callback: callback,
|
callback: callback,
|
||||||
cookieGenerator: cookieGenerator,
|
cookieGenerator: cookieGenerator,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||||
if h.callback(conn.RemoteAddr(), nil) {
|
if h.callback(conn.RemoteAddr(), nil) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||||
data, err := h.cookieGenerator.DecodeToken(token)
|
data, err := h.cookieGenerator.DecodeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||||
|
|
|
@ -2,6 +2,7 @@ package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
"github.com/bifurcation/mint"
|
||||||
|
|
||||||
|
@ -9,22 +10,37 @@ import (
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockConn struct {
|
||||||
|
remoteAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = &mockConn{}
|
||||||
|
|
||||||
|
func (c *mockConn) Read([]byte) (int, error) { panic("not implemented") }
|
||||||
|
func (c *mockConn) Write([]byte) (int, error) { panic("not implemented") }
|
||||||
|
func (c *mockConn) Close() error { panic("not implemented") }
|
||||||
|
func (c *mockConn) LocalAddr() net.Addr { panic("not implemented") }
|
||||||
|
func (c *mockConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||||
|
func (c *mockConn) SetReadDeadline(time.Time) error { panic("not implemented") }
|
||||||
|
func (c *mockConn) SetWriteDeadline(time.Time) error { panic("not implemented") }
|
||||||
|
func (c *mockConn) SetDeadline(time.Time) error { panic("not implemented") }
|
||||||
|
|
||||||
var callbackReturn bool
|
var callbackReturn bool
|
||||||
var mockCallback = func(net.Addr, *Cookie) bool {
|
var mockCallback = func(net.Addr, *Cookie) bool {
|
||||||
return callbackReturn
|
return callbackReturn
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ = Describe("Cookie Handler", func() {
|
var _ = Describe("Cookie Handler", func() {
|
||||||
var ch *cookieHandler
|
var ch *CookieHandler
|
||||||
var conn *mint.Conn
|
var conn *mint.Conn
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
callbackReturn = false
|
callbackReturn = false
|
||||||
var err error
|
var err error
|
||||||
ch, err = newCookieHandler(mockCallback)
|
ch, err = NewCookieHandler(mockCallback)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46}
|
addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46}
|
||||||
conn = mint.NewConn(&fakeConn{remoteAddr: addr}, &mint.Config{}, false)
|
conn = mint.NewConn(&mockConn{remoteAddr: addr}, &mint.Config{}, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("generates and validates a token", func() {
|
It("generates and validates a token", func() {
|
||||||
|
|
|
@ -381,10 +381,6 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
||||||
h.divNonceChan <- data
|
h.divNonceChan <- data
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
|
|
||||||
panic("not needed for cryptoSetupServer")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupClient) sendCHLO() error {
|
func (h *cryptoSetupClient) sendCHLO() error {
|
||||||
h.clientHelloCounter++
|
h.clientHelloCounter++
|
||||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||||
|
|
|
@ -458,10 +458,6 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
|
||||||
panic("not needed for cryptoSetupServer")
|
panic("not needed for cryptoSetupServer")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
|
|
||||||
panic("not needed for cryptoSetupServer")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||||
if len(nonce) != 32 {
|
if len(nonce) != 32 {
|
||||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
"github.com/bifurcation/mint"
|
||||||
|
@ -12,6 +11,9 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
|
||||||
|
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
|
||||||
|
|
||||||
// KeyDerivationFunction is used for key derivation
|
// KeyDerivationFunction is used for key derivation
|
||||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
||||||
|
|
||||||
|
@ -20,68 +22,31 @@ type cryptoSetupTLS struct {
|
||||||
|
|
||||||
perspective protocol.Perspective
|
perspective protocol.Perspective
|
||||||
|
|
||||||
tls mintTLS
|
|
||||||
conn *fakeConn
|
|
||||||
|
|
||||||
nextPacketType protocol.PacketType
|
|
||||||
|
|
||||||
keyDerivation KeyDerivationFunction
|
keyDerivation KeyDerivationFunction
|
||||||
nullAEAD crypto.AEAD
|
nullAEAD crypto.AEAD
|
||||||
aead crypto.AEAD
|
aead crypto.AEAD
|
||||||
|
|
||||||
aeadChanged chan<- protocol.EncryptionLevel
|
tls MintTLS
|
||||||
|
cryptoStream *CryptoStreamConn
|
||||||
|
aeadChanged chan<- protocol.EncryptionLevel
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||||
func NewCryptoSetupTLSServer(
|
func NewCryptoSetupTLSServer(
|
||||||
cryptoStream io.ReadWriter,
|
tls MintTLS,
|
||||||
connID protocol.ConnectionID,
|
cryptoStream *CryptoStreamConn,
|
||||||
tlsConfig *tls.Config,
|
nullAEAD crypto.AEAD,
|
||||||
remoteAddr net.Addr,
|
|
||||||
params *TransportParameters,
|
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
checkCookie func(net.Addr, *Cookie) bool,
|
|
||||||
supportedVersions []protocol.VersionNumber,
|
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetup, error) {
|
) CryptoSetup {
|
||||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mintConf.RequireCookie = true
|
|
||||||
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mintConf.CookieProtector, err = mint.NewDefaultCookieProtector()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn := &fakeConn{
|
|
||||||
stream: cryptoStream,
|
|
||||||
pers: protocol.PerspectiveServer,
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
}
|
|
||||||
mintConn := mint.Server(conn, mintConf)
|
|
||||||
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
|
|
||||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
perspective: protocol.PerspectiveServer,
|
tls: tls,
|
||||||
tls: &mintController{mintConn},
|
cryptoStream: cryptoStream,
|
||||||
conn: conn,
|
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
|
perspective: protocol.PerspectiveServer,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
aeadChanged: aeadChanged,
|
aeadChanged: aeadChanged,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||||
|
@ -89,60 +54,44 @@ func NewCryptoSetupTLSClient(
|
||||||
cryptoStream io.ReadWriter,
|
cryptoStream io.ReadWriter,
|
||||||
connID protocol.ConnectionID,
|
connID protocol.ConnectionID,
|
||||||
hostname string,
|
hostname string,
|
||||||
tlsConfig *tls.Config,
|
|
||||||
params *TransportParameters,
|
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
aeadChanged chan<- protocol.EncryptionLevel,
|
aeadChanged chan<- protocol.EncryptionLevel,
|
||||||
initialVersion protocol.VersionNumber,
|
tls MintTLS,
|
||||||
supportedVersions []protocol.VersionNumber,
|
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) (CryptoSetup, error) {
|
) (CryptoSetup, error) {
|
||||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mintConf.ServerName = hostname
|
|
||||||
conn := &fakeConn{
|
|
||||||
stream: cryptoStream,
|
|
||||||
pers: protocol.PerspectiveClient,
|
|
||||||
}
|
|
||||||
mintConn := mint.Client(conn, mintConf)
|
|
||||||
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
|
|
||||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
conn: conn,
|
perspective: protocol.PerspectiveClient,
|
||||||
perspective: protocol.PerspectiveClient,
|
tls: tls,
|
||||||
tls: &mintController{mintConn},
|
nullAEAD: nullAEAD,
|
||||||
nullAEAD: nullAEAD,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
aeadChanged: aeadChanged,
|
||||||
aeadChanged: aeadChanged,
|
|
||||||
nextPacketType: protocol.PacketTypeInitial,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||||
|
if h.perspective == protocol.PerspectiveServer {
|
||||||
|
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
|
||||||
|
// send out that data now
|
||||||
|
if _, err := h.cryptoStream.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
handshakeLoop:
|
handshakeLoop:
|
||||||
for {
|
for {
|
||||||
switch alert := h.tls.Handshake(); alert {
|
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||||
case mint.AlertStatelessRetry:
|
|
||||||
case mint.AlertNoAlert: // handshake complete
|
|
||||||
break handshakeLoop
|
|
||||||
case mint.AlertWouldBlock:
|
|
||||||
h.determineNextPacketType()
|
|
||||||
if err := h.conn.Continue(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||||
}
|
}
|
||||||
|
switch h.tls.State() {
|
||||||
|
case mint.StateClientStart: // this happens if a stateless retry is performed
|
||||||
|
return ErrCloseSessionForRetry
|
||||||
|
case mint.StateClientConnected, mint.StateServerConnected:
|
||||||
|
break handshakeLoop
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
aead, err := h.keyDerivation(h.tls, h.perspective)
|
||||||
|
@ -209,35 +158,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
|
||||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) determineNextPacketType() error {
|
|
||||||
h.mutex.Lock()
|
|
||||||
defer h.mutex.Unlock()
|
|
||||||
state := h.tls.State().HandshakeState
|
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
|
||||||
switch state {
|
|
||||||
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
|
|
||||||
h.nextPacketType = protocol.PacketTypeRetry
|
|
||||||
case "ServerStateWaitFinished":
|
|
||||||
h.nextPacketType = protocol.PacketTypeHandshake
|
|
||||||
default:
|
|
||||||
// TODO: accept 0-RTT data
|
|
||||||
return fmt.Errorf("Unexpected handshake state: %s", state)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// client
|
|
||||||
if state != "ClientStateWaitSH" {
|
|
||||||
h.nextPacketType = protocol.PacketTypeHandshake
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
|
|
||||||
h.mutex.RLock()
|
|
||||||
defer h.mutex.RUnlock()
|
|
||||||
return h.nextPacketType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
||||||
panic("diversification nonce not needed for TLS")
|
panic("diversification nonce not needed for TLS")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -10,7 +9,6 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
@ -23,52 +21,33 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
|
||||||
var _ = Describe("TLS Crypto Setup", func() {
|
var _ = Describe("TLS Crypto Setup", func() {
|
||||||
var (
|
var (
|
||||||
cs *cryptoSetupTLS
|
cs *cryptoSetupTLS
|
||||||
paramsChan chan TransportParameters
|
|
||||||
aeadChanged chan protocol.EncryptionLevel
|
aeadChanged chan protocol.EncryptionLevel
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
paramsChan = make(chan TransportParameters)
|
|
||||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||||
csInt, err := NewCryptoSetupTLSServer(
|
cs = NewCryptoSetupTLSServer(
|
||||||
nil,
|
nil,
|
||||||
1,
|
NewCryptoStreamConn(nil),
|
||||||
testdata.GetTLSConfig(),
|
nil, // AEAD
|
||||||
nil,
|
|
||||||
&TransportParameters{},
|
|
||||||
paramsChan,
|
|
||||||
aeadChanged,
|
aeadChanged,
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
protocol.VersionTLS,
|
protocol.VersionTLS,
|
||||||
)
|
).(*cryptoSetupTLS)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
cs = csInt.(*cryptoSetupTLS)
|
|
||||||
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors when the handshake fails", func() {
|
It("errors when the handshake fails", func() {
|
||||||
alert := mint.AlertBadRecordMAC
|
alert := mint.AlertBadRecordMAC
|
||||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert)
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(alert)
|
||||||
err := cs.HandleCryptoStream()
|
err := cs.HandleCryptoStream()
|
||||||
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
|
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("continues shaking hands when mint says that it would block", func() {
|
|
||||||
cs.conn.stream = &bytes.Buffer{}
|
|
||||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
|
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{})
|
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
|
||||||
cs.keyDerivation = mockKeyDerivation
|
|
||||||
err := cs.HandleCryptoStream()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("derives keys", func() {
|
It("derives keys", func() {
|
||||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||||
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
|
||||||
cs.keyDerivation = mockKeyDerivation
|
cs.keyDerivation = mockKeyDerivation
|
||||||
err := cs.HandleCryptoStream()
|
err := cs.HandleCryptoStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -76,64 +55,22 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||||
Expect(aeadChanged).To(BeClosed())
|
Expect(aeadChanged).To(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("determining the packet type", func() {
|
It("handshakes until it is connected", func() {
|
||||||
Context("for the client", func() {
|
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||||
var csClient *cryptoSetupTLS
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
|
||||||
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerNegotiated).Times(9)
|
||||||
BeforeEach(func() {
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
|
||||||
csInt, err := NewCryptoSetupTLSClient(
|
cs.keyDerivation = mockKeyDerivation
|
||||||
nil,
|
err := cs.HandleCryptoStream()
|
||||||
1,
|
Expect(err).ToNot(HaveOccurred())
|
||||||
"quic.clemente.io",
|
Expect(aeadChanged).To(Receive())
|
||||||
testdata.GetTLSConfig(),
|
|
||||||
&TransportParameters{},
|
|
||||||
paramsChan,
|
|
||||||
aeadChanged,
|
|
||||||
protocol.VersionTLS,
|
|
||||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
|
||||||
protocol.VersionTLS,
|
|
||||||
)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
csClient = csInt.(*cryptoSetupTLS)
|
|
||||||
csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends a Client Initial first", func() {
|
|
||||||
Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeInitial))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends a Handshake packet after the server sent a Server Hello", func() {
|
|
||||||
csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"})
|
|
||||||
err := csClient.determineNextPacketType()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("for the server", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends a Stateless Retry packet", func() {
|
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"})
|
|
||||||
err := cs.determineNextPacketType()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeRetry))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("sends Handshake packet", func() {
|
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"})
|
|
||||||
err := cs.determineNextPacketType()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeHandshake))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("escalating crypto", func() {
|
Context("escalating crypto", func() {
|
||||||
doHandshake := func() {
|
doHandshake := func() {
|
||||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||||
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
|
||||||
cs.keyDerivation = mockKeyDerivation
|
cs.keyDerivation = mockKeyDerivation
|
||||||
err := cs.HandleCryptoStream()
|
err := cs.HandleCryptoStream()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -240,3 +177,33 @@ var _ = Describe("TLS Crypto Setup", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var _ = Describe("TLS Crypto Setup, for the client", func() {
|
||||||
|
var (
|
||||||
|
cs *cryptoSetupTLS
|
||||||
|
aeadChanged chan protocol.EncryptionLevel
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||||
|
csInt, err := NewCryptoSetupTLSClient(
|
||||||
|
nil,
|
||||||
|
0,
|
||||||
|
"quic.clemente.io",
|
||||||
|
aeadChanged,
|
||||||
|
nil, // mintTLS
|
||||||
|
protocol.VersionTLS,
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
cs = csInt.(*cryptoSetupTLS)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns when a retry is performed", func() {
|
||||||
|
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||||
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||||
|
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateClientStart)
|
||||||
|
err := cs.HandleCryptoStream()
|
||||||
|
Expect(err).To(MatchError(ErrCloseSessionForRetry))
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
|
101
internal/handshake/crypto_stream_conn.go
Normal file
101
internal/handshake/crypto_stream_conn.go
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The CryptoStreamConn is used as the net.Conn passed to mint.
|
||||||
|
// It has two operating modes:
|
||||||
|
// 1. It can read and write to bytes.Buffers.
|
||||||
|
// 2. It can use a quic.Stream for reading and writing.
|
||||||
|
// The buffer-mode is only used by the server, in order to statelessly handle retries.
|
||||||
|
type CryptoStreamConn struct {
|
||||||
|
remoteAddr net.Addr
|
||||||
|
|
||||||
|
// the buffers are used before the session is initialized
|
||||||
|
readBuf bytes.Buffer
|
||||||
|
writeBuf bytes.Buffer
|
||||||
|
|
||||||
|
// stream will be set once the session is initialized
|
||||||
|
stream io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = &CryptoStreamConn{}
|
||||||
|
|
||||||
|
// NewCryptoStreamConn creates a new CryptoStreamConn
|
||||||
|
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
|
||||||
|
return &CryptoStreamConn{remoteAddr: remoteAddr}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
|
||||||
|
if c.stream != nil {
|
||||||
|
return c.stream.Read(b)
|
||||||
|
}
|
||||||
|
return c.readBuf.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDataForReading adds data to the read buffer.
|
||||||
|
// This data will ONLY be read when the stream has not been set.
|
||||||
|
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
|
||||||
|
c.readBuf.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
|
||||||
|
if c.stream != nil {
|
||||||
|
return c.stream.Write(p)
|
||||||
|
}
|
||||||
|
return c.writeBuf.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
|
||||||
|
func (c *CryptoStreamConn) GetDataForWriting() []byte {
|
||||||
|
defer c.writeBuf.Reset()
|
||||||
|
data := make([]byte, c.writeBuf.Len())
|
||||||
|
copy(data, c.writeBuf.Bytes())
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStream sets the stream.
|
||||||
|
// After setting the stream, the read and write buffer won't be used any more.
|
||||||
|
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
|
||||||
|
c.stream = stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush copies the contents of the write buffer to the stream
|
||||||
|
func (c *CryptoStreamConn) Flush() (int, error) {
|
||||||
|
n, err := io.Copy(c.stream, &c.writeBuf)
|
||||||
|
return int(n), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is not implemented
|
||||||
|
func (c *CryptoStreamConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr is not implemented
|
||||||
|
func (c *CryptoStreamConn) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote address
|
||||||
|
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
|
||||||
|
return c.remoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline is not implemented
|
||||||
|
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline is not implemented
|
||||||
|
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline is not implemented
|
||||||
|
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
67
internal/handshake/crypto_stream_conn_test.go
Normal file
67
internal/handshake/crypto_stream_conn_test.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("CryptoStreamConn", func() {
|
||||||
|
var (
|
||||||
|
csc *CryptoStreamConn
|
||||||
|
remoteAddr net.Addr
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||||
|
csc = NewCryptoStreamConn(remoteAddr)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads from the read buffer, when no stream is set", func() {
|
||||||
|
csc.AddDataForReading([]byte("foobar"))
|
||||||
|
data := make([]byte, 4)
|
||||||
|
n, err := csc.Read(data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(4))
|
||||||
|
Expect(data).To(Equal([]byte("foob")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes to the write buffer, when no stream is set", func() {
|
||||||
|
csc.Write([]byte("foo"))
|
||||||
|
Expect(csc.GetDataForWriting()).To(Equal([]byte("foo")))
|
||||||
|
csc.Write([]byte("bar"))
|
||||||
|
Expect(csc.GetDataForWriting()).To(Equal([]byte("bar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads from the stream, if available", func() {
|
||||||
|
csc.stream = &bytes.Buffer{}
|
||||||
|
csc.stream.Write([]byte("foobar"))
|
||||||
|
data := make([]byte, 3)
|
||||||
|
n, err := csc.Read(data)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(3))
|
||||||
|
Expect(data).To(Equal([]byte("foo")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("writes to the stream, if available", func() {
|
||||||
|
stream := &bytes.Buffer{}
|
||||||
|
csc.SetStream(stream)
|
||||||
|
csc.Write([]byte("foobar"))
|
||||||
|
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the remote address", func() {
|
||||||
|
Expect(csc.RemoteAddr()).To(Equal(remoteAddr))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("has unimplemented methods", func() {
|
||||||
|
Expect(csc.Close()).ToNot(HaveOccurred())
|
||||||
|
Expect(csc.SetDeadline(time.Time{})).ToNot(HaveOccurred())
|
||||||
|
Expect(csc.SetReadDeadline(time.Time{})).ToNot(HaveOccurred())
|
||||||
|
Expect(csc.SetWriteDeadline(time.Time{})).ToNot(HaveOccurred())
|
||||||
|
Expect(csc.LocalAddr()).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
|
@ -1,6 +1,10 @@
|
||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,14 +14,33 @@ type Sealer interface {
|
||||||
Overhead() int
|
Overhead() int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A TLSExtensionHandler sends and received the QUIC TLS extension.
|
||||||
|
// It provides the parameters sent by the peer on a channel.
|
||||||
|
type TLSExtensionHandler interface {
|
||||||
|
Send(mint.HandshakeType, *mint.ExtensionList) error
|
||||||
|
Receive(mint.HandshakeType, *mint.ExtensionList) error
|
||||||
|
GetPeerParams() <-chan TransportParameters
|
||||||
|
}
|
||||||
|
|
||||||
|
// MintTLS combines some methods needed to interact with mint.
|
||||||
|
type MintTLS interface {
|
||||||
|
crypto.TLSExporter
|
||||||
|
|
||||||
|
// additional methods
|
||||||
|
Handshake() mint.Alert
|
||||||
|
State() mint.State
|
||||||
|
|
||||||
|
SetCryptoStream(io.ReadWriter)
|
||||||
|
SetExtensionHandler(mint.AppExtensionHandler) error
|
||||||
|
}
|
||||||
|
|
||||||
// CryptoSetup is a crypto setup
|
// CryptoSetup is a crypto setup
|
||||||
type CryptoSetup interface {
|
type CryptoSetup interface {
|
||||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||||
HandleCryptoStream() error
|
HandleCryptoStream() error
|
||||||
// TODO: clean up this interface
|
// TODO: clean up this interface
|
||||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||||
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
|
|
||||||
|
|
||||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||||
|
|
|
@ -1,127 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
gocrypto "crypto"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bifurcation/mint"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
)
|
|
||||||
|
|
||||||
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
|
|
||||||
mconf := &mint.Config{
|
|
||||||
NonBlocking: true,
|
|
||||||
CipherSuites: []mint.CipherSuite{
|
|
||||||
mint.TLS_AES_128_GCM_SHA256,
|
|
||||||
mint.TLS_AES_256_GCM_SHA384,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if tlsConf != nil {
|
|
||||||
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
|
|
||||||
for i, certChain := range tlsConf.Certificates {
|
|
||||||
mconf.Certificates[i] = &mint.Certificate{
|
|
||||||
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
|
|
||||||
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
|
|
||||||
}
|
|
||||||
for j, cert := range certChain.Certificate {
|
|
||||||
c, err := x509.ParseCertificate(cert)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mconf.Certificates[i].Chain[j] = c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return mconf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type mintTLS interface {
|
|
||||||
// These two methods are the same as the crypto.TLSExporter interface.
|
|
||||||
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
|
|
||||||
GetCipherSuite() mint.CipherSuiteParams
|
|
||||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
|
||||||
// additional methods
|
|
||||||
Handshake() mint.Alert
|
|
||||||
State() mint.ConnectionState
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ crypto.TLSExporter = (mintTLS)(nil)
|
|
||||||
|
|
||||||
type mintController struct {
|
|
||||||
conn *mint.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ mintTLS = &mintController{}
|
|
||||||
|
|
||||||
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
|
|
||||||
return mc.conn.State().CipherSuite
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
|
||||||
return mc.conn.ComputeExporter(label, context, keyLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mintController) Handshake() mint.Alert {
|
|
||||||
return mc.conn.Handshake()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mintController) State() mint.ConnectionState {
|
|
||||||
return mc.conn.State()
|
|
||||||
}
|
|
||||||
|
|
||||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
|
||||||
// so we wrap a stream such that implements a net.Conn
|
|
||||||
type fakeConn struct {
|
|
||||||
stream io.ReadWriter
|
|
||||||
pers protocol.Perspective
|
|
||||||
remoteAddr net.Addr
|
|
||||||
|
|
||||||
blockRead bool
|
|
||||||
writeBuffer bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ net.Conn = &fakeConn{}
|
|
||||||
|
|
||||||
func (c *fakeConn) Read(b []byte) (int, error) {
|
|
||||||
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
c.blockRead = true // block the next Read call
|
|
||||||
return c.stream.Read(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeConn) Write(p []byte) (int, error) {
|
|
||||||
if c.pers == protocol.PerspectiveClient {
|
|
||||||
return c.stream.Write(p)
|
|
||||||
}
|
|
||||||
// Buffer all writes by the server.
|
|
||||||
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
|
|
||||||
return c.writeBuffer.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeConn) Continue() error {
|
|
||||||
c.blockRead = false
|
|
||||||
if c.pers == protocol.PerspectiveClient {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// write all contents of the write buffer to the stream.
|
|
||||||
_, err := c.stream.Write(c.writeBuffer.Bytes())
|
|
||||||
c.writeBuffer.Reset()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeConn) Close() error { return nil }
|
|
||||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
|
||||||
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
|
||||||
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
|
|
||||||
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
|
|
||||||
func (c *fakeConn) SetDeadline(time.Time) error { return nil }
|
|
|
@ -1,72 +0,0 @@
|
||||||
package handshake
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Fake Conn", func() {
|
|
||||||
var (
|
|
||||||
c *fakeConn
|
|
||||||
stream *bytes.Buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
|
||||||
stream = &bytes.Buffer{}
|
|
||||||
c = &fakeConn{stream: stream}
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Reading", func() {
|
|
||||||
It("doesn't return any new data after one Read call", func() {
|
|
||||||
stream.Write([]byte("foobar"))
|
|
||||||
b := make([]byte, 3)
|
|
||||||
_, err := c.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(b).To(Equal([]byte("foo")))
|
|
||||||
n, err := c.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(BeZero())
|
|
||||||
})
|
|
||||||
|
|
||||||
It("allows more Read calls after unblocking", func() {
|
|
||||||
stream.Write([]byte("foobar"))
|
|
||||||
b := make([]byte, 3)
|
|
||||||
_, err := c.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
err = c.Continue()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
_, err = c.Read(b)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(b).To(Equal([]byte("bar")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("Writing", func() {
|
|
||||||
It("writes directly when acting as a client", func() {
|
|
||||||
c.pers = protocol.PerspectiveClient
|
|
||||||
_, err := c.Write([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("only writes after flushing when acting as a server", func() {
|
|
||||||
c.pers = protocol.PerspectiveServer
|
|
||||||
_, err := c.Write([]byte("foobar"))
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(stream.Bytes()).To(BeEmpty())
|
|
||||||
err = c.Continue()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns its remote address", func() {
|
|
||||||
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
|
||||||
c.remoteAddr = addr
|
|
||||||
Expect(c.RemoteAddr()).To(Equal(addr))
|
|
||||||
})
|
|
||||||
})
|
|
|
@ -13,8 +13,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type extensionHandlerClient struct {
|
type extensionHandlerClient struct {
|
||||||
params *TransportParameters
|
ourParams *TransportParameters
|
||||||
paramsChan chan<- TransportParameters
|
paramsChan chan TransportParameters
|
||||||
|
|
||||||
initialVersion protocol.VersionNumber
|
initialVersion protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
|
@ -22,16 +22,17 @@ type extensionHandlerClient struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
||||||
|
var _ TLSExtensionHandler = &extensionHandlerClient{}
|
||||||
|
|
||||||
func newExtensionHandlerClient(
|
func NewExtensionHandlerClient(
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) *extensionHandlerClient {
|
) TLSExtensionHandler {
|
||||||
|
paramsChan := make(chan TransportParameters, 1)
|
||||||
return &extensionHandlerClient{
|
return &extensionHandlerClient{
|
||||||
params: params,
|
ourParams: params,
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
initialVersion: initialVersion,
|
initialVersion: initialVersion,
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
|
@ -46,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
|
||||||
|
|
||||||
data, err := syntax.Marshal(clientHelloTransportParameters{
|
data, err := syntax.Marshal(clientHelloTransportParameters{
|
||||||
InitialVersion: uint32(h.initialVersion),
|
InitialVersion: uint32(h.initialVersion),
|
||||||
Parameters: h.params.getTransportParameters(),
|
Parameters: h.ourParams.getTransportParameters(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -123,3 +124,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
|
||||||
h.paramsChan <- *params
|
h.paramsChan <- *params
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
|
||||||
|
return h.paramsChan
|
||||||
|
}
|
||||||
|
|
|
@ -13,15 +13,12 @@ import (
|
||||||
|
|
||||||
var _ = Describe("TLS Extension Handler, for the client", func() {
|
var _ = Describe("TLS Extension Handler, for the client", func() {
|
||||||
var (
|
var (
|
||||||
handler *extensionHandlerClient
|
handler *extensionHandlerClient
|
||||||
el mint.ExtensionList
|
el mint.ExtensionList
|
||||||
paramsChan chan TransportParameters
|
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
|
handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever).(*extensionHandlerClient)
|
||||||
paramsChan = make(chan TransportParameters, 1)
|
|
||||||
handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever)
|
|
||||||
el = make(mint.ExtensionList, 0)
|
el = make(mint.ExtensionList, 0)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -81,7 +78,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
|
||||||
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var params TransportParameters
|
var params TransportParameters
|
||||||
Expect(paramsChan).To(Receive(¶ms))
|
Expect(handler.GetPeerParams()).To(Receive(¶ms))
|
||||||
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -14,26 +14,27 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type extensionHandlerServer struct {
|
type extensionHandlerServer struct {
|
||||||
params *TransportParameters
|
ourParams *TransportParameters
|
||||||
paramsChan chan<- TransportParameters
|
paramsChan chan TransportParameters
|
||||||
|
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
supportedVersions []protocol.VersionNumber
|
supportedVersions []protocol.VersionNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
||||||
|
var _ TLSExtensionHandler = &extensionHandlerServer{}
|
||||||
|
|
||||||
func newExtensionHandlerServer(
|
func NewExtensionHandlerServer(
|
||||||
params *TransportParameters,
|
params *TransportParameters,
|
||||||
paramsChan chan<- TransportParameters,
|
|
||||||
supportedVersions []protocol.VersionNumber,
|
supportedVersions []protocol.VersionNumber,
|
||||||
version protocol.VersionNumber,
|
version protocol.VersionNumber,
|
||||||
) *extensionHandlerServer {
|
) TLSExtensionHandler {
|
||||||
|
paramsChan := make(chan TransportParameters, 1)
|
||||||
return &extensionHandlerServer{
|
return &extensionHandlerServer{
|
||||||
params: params,
|
ourParams: params,
|
||||||
paramsChan: paramsChan,
|
paramsChan: paramsChan,
|
||||||
version: version,
|
|
||||||
supportedVersions: supportedVersions,
|
supportedVersions: supportedVersions,
|
||||||
|
version: version,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +44,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
|
||||||
}
|
}
|
||||||
|
|
||||||
transportParams := append(
|
transportParams := append(
|
||||||
h.params.getTransportParameters(),
|
h.ourParams.getTransportParameters(),
|
||||||
// TODO(#855): generate a real token
|
// TODO(#855): generate a real token
|
||||||
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
|
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
|
||||||
)
|
)
|
||||||
|
@ -105,3 +106,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
|
||||||
h.paramsChan <- *params
|
h.paramsChan <- *params
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
|
||||||
|
return h.paramsChan
|
||||||
|
}
|
||||||
|
|
|
@ -20,15 +20,12 @@ func parameterMapToList(paramMap map[transportParameterID][]byte) []transportPar
|
||||||
|
|
||||||
var _ = Describe("TLS Extension Handler, for the server", func() {
|
var _ = Describe("TLS Extension Handler, for the server", func() {
|
||||||
var (
|
var (
|
||||||
handler *extensionHandlerServer
|
handler *extensionHandlerServer
|
||||||
el mint.ExtensionList
|
el mint.ExtensionList
|
||||||
paramsChan chan TransportParameters
|
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
|
handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever).(*extensionHandlerServer)
|
||||||
paramsChan = make(chan TransportParameters, 1)
|
|
||||||
handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever)
|
|
||||||
el = make(mint.ExtensionList, 0)
|
el = make(mint.ExtensionList, 0)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -91,7 +88,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
|
||||||
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
var params TransportParameters
|
var params TransportParameters
|
||||||
Expect(paramsChan).To(Receive(¶ms))
|
Expect(handler.GetPeerParams()).To(Receive(¶ms))
|
||||||
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go"
|
//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS"
|
||||||
|
//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
|
||||||
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||||
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
|
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
|
||||||
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
|
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
|
||||||
|
|
|
@ -1,83 +1,106 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
// Source: ../handshake/mint_utils.go
|
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
|
||||||
|
|
||||||
package mockhandshake
|
package mockhandshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
io "io"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
mint "github.com/bifurcation/mint"
|
mint "github.com/bifurcation/mint"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockmintTLS is a mock of mintTLS interface
|
// MockMintTLS is a mock of MintTLS interface
|
||||||
type MockmintTLS struct {
|
type MockMintTLS struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockmintTLSMockRecorder
|
recorder *MockMintTLSMockRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockmintTLSMockRecorder is the mock recorder for MockmintTLS
|
// MockMintTLSMockRecorder is the mock recorder for MockMintTLS
|
||||||
type MockmintTLSMockRecorder struct {
|
type MockMintTLSMockRecorder struct {
|
||||||
mock *MockmintTLS
|
mock *MockMintTLS
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMockmintTLS creates a new mock instance
|
// NewMockMintTLS creates a new mock instance
|
||||||
func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS {
|
func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
|
||||||
mock := &MockmintTLS{ctrl: ctrl}
|
mock := &MockMintTLS{ctrl: ctrl}
|
||||||
mock.recorder = &MockmintTLSMockRecorder{mock}
|
mock.recorder = &MockMintTLSMockRecorder{mock}
|
||||||
return mock
|
return mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder {
|
func (_m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
|
||||||
return _m.recorder
|
return _m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCipherSuite mocks base method
|
|
||||||
func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams {
|
|
||||||
ret := _m.ctrl.Call(_m, "GetCipherSuite")
|
|
||||||
ret0, _ := ret[0].(mint.CipherSuiteParams)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCipherSuite indicates an expected call of GetCipherSuite
|
|
||||||
func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call {
|
|
||||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComputeExporter mocks base method
|
// ComputeExporter mocks base method
|
||||||
func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
func (_m *MockMintTLS) ComputeExporter(_param0 string, _param1 []byte, _param2 int) ([]byte, error) {
|
||||||
ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength)
|
ret := _m.ctrl.Call(_m, "ComputeExporter", _param0, _param1, _param2)
|
||||||
ret0, _ := ret[0].([]byte)
|
ret0, _ := ret[0].([]byte)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// ComputeExporter indicates an expected call of ComputeExporter
|
// ComputeExporter indicates an expected call of ComputeExporter
|
||||||
func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
func (_mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCipherSuite mocks base method
|
||||||
|
func (_m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams {
|
||||||
|
ret := _m.ctrl.Call(_m, "GetCipherSuite")
|
||||||
|
ret0, _ := ret[0].(mint.CipherSuiteParams)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCipherSuite indicates an expected call of GetCipherSuite
|
||||||
|
func (_mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake mocks base method
|
// Handshake mocks base method
|
||||||
func (_m *MockmintTLS) Handshake() mint.Alert {
|
func (_m *MockMintTLS) Handshake() mint.Alert {
|
||||||
ret := _m.ctrl.Call(_m, "Handshake")
|
ret := _m.ctrl.Call(_m, "Handshake")
|
||||||
ret0, _ := ret[0].(mint.Alert)
|
ret0, _ := ret[0].(mint.Alert)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake indicates an expected call of Handshake
|
// Handshake indicates an expected call of Handshake
|
||||||
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call {
|
func (_mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
|
||||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake))
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCryptoStream mocks base method
|
||||||
|
func (_m *MockMintTLS) SetCryptoStream(_param0 io.ReadWriter) {
|
||||||
|
_m.ctrl.Call(_m, "SetCryptoStream", _param0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCryptoStream indicates an expected call of SetCryptoStream
|
||||||
|
func (_mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExtensionHandler mocks base method
|
||||||
|
func (_m *MockMintTLS) SetExtensionHandler(_param0 mint.AppExtensionHandler) error {
|
||||||
|
ret := _m.ctrl.Call(_m, "SetExtensionHandler", _param0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExtensionHandler indicates an expected call of SetExtensionHandler
|
||||||
|
func (_mr *MockMintTLSMockRecorder) SetExtensionHandler(arg0 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetExtensionHandler", reflect.TypeOf((*MockMintTLS)(nil).SetExtensionHandler), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// State mocks base method
|
// State mocks base method
|
||||||
func (_m *MockmintTLS) State() mint.ConnectionState {
|
func (_m *MockMintTLS) State() mint.State {
|
||||||
ret := _m.ctrl.Call(_m, "State")
|
ret := _m.ctrl.Call(_m, "State")
|
||||||
ret0, _ := ret[0].(mint.ConnectionState)
|
ret0, _ := ret[0].(mint.State)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// State indicates an expected call of State
|
// State indicates an expected call of State
|
||||||
func (_mr *MockmintTLSMockRecorder) State() *gomock.Call {
|
func (_mr *MockMintTLSMockRecorder) State() *gomock.Call {
|
||||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State))
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State))
|
||||||
}
|
}
|
||||||
|
|
71
internal/mocks/tls_extension_handler.go
Normal file
71
internal/mocks/tls_extension_handler.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler)
|
||||||
|
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
mint "github.com/bifurcation/mint"
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface
|
||||||
|
type MockTLSExtensionHandler struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockTLSExtensionHandlerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler
|
||||||
|
type MockTLSExtensionHandlerMockRecorder struct {
|
||||||
|
mock *MockTLSExtensionHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockTLSExtensionHandler creates a new mock instance
|
||||||
|
func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler {
|
||||||
|
mock := &MockTLSExtensionHandler{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (_m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder {
|
||||||
|
return _m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerParams mocks base method
|
||||||
|
func (_m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters {
|
||||||
|
ret := _m.ctrl.Call(_m, "GetPeerParams")
|
||||||
|
ret0, _ := ret[0].(<-chan handshake.TransportParameters)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerParams indicates an expected call of GetPeerParams
|
||||||
|
func (_mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive mocks base method
|
||||||
|
func (_m *MockTLSExtensionHandler) Receive(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
|
||||||
|
ret := _m.ctrl.Call(_m, "Receive", _param0, _param1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive indicates an expected call of Receive
|
||||||
|
func (_mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send mocks base method
|
||||||
|
func (_m *MockTLSExtensionHandler) Send(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
|
||||||
|
ret := _m.ctrl.Call(_m, "Send", _param0, _param1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send indicates an expected call of Send
|
||||||
|
func (_mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1)
|
||||||
|
}
|
150
mint_utils.go
Normal file
150
mint_utils.go
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
gocrypto "crypto"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mintController struct {
|
||||||
|
csc *handshake.CryptoStreamConn
|
||||||
|
conn *mint.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ handshake.MintTLS = &mintController{}
|
||||||
|
|
||||||
|
func newMintController(
|
||||||
|
csc *handshake.CryptoStreamConn,
|
||||||
|
mconf *mint.Config,
|
||||||
|
pers protocol.Perspective,
|
||||||
|
) handshake.MintTLS {
|
||||||
|
var conn *mint.Conn
|
||||||
|
if pers == protocol.PerspectiveClient {
|
||||||
|
conn = mint.Client(csc, mconf)
|
||||||
|
} else {
|
||||||
|
conn = mint.Server(csc, mconf)
|
||||||
|
}
|
||||||
|
return &mintController{
|
||||||
|
csc: csc,
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
|
||||||
|
return mc.conn.State().CipherSuite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||||
|
return mc.conn.ComputeExporter(label, context, keyLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mintController) Handshake() mint.Alert {
|
||||||
|
return mc.conn.Handshake()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mintController) State() mint.State {
|
||||||
|
return mc.conn.State().HandshakeState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
|
||||||
|
mc.csc.SetStream(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mintController) SetExtensionHandler(h mint.AppExtensionHandler) error {
|
||||||
|
return mc.conn.SetExtensionHandler(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
|
||||||
|
mconf := &mint.Config{
|
||||||
|
NonBlocking: true,
|
||||||
|
CipherSuites: []mint.CipherSuite{
|
||||||
|
mint.TLS_AES_128_GCM_SHA256,
|
||||||
|
mint.TLS_AES_256_GCM_SHA384,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if tlsConf != nil {
|
||||||
|
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
|
||||||
|
for i, certChain := range tlsConf.Certificates {
|
||||||
|
mconf.Certificates[i] = &mint.Certificate{
|
||||||
|
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
|
||||||
|
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
|
||||||
|
}
|
||||||
|
for j, cert := range certChain.Certificate {
|
||||||
|
c, err := x509.ParseCertificate(cert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mconf.Certificates[i].Chain[j] = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mconf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets
|
||||||
|
// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0.
|
||||||
|
func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) {
|
||||||
|
unpacker := &packetUnpacker{aead: &nullAEAD{aead}, version: version}
|
||||||
|
packet, err := unpacker.Unpack(hdr.Raw, hdr, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var frame *wire.StreamFrame
|
||||||
|
for _, f := range packet.frames {
|
||||||
|
var ok bool
|
||||||
|
frame, ok = f.(*wire.StreamFrame)
|
||||||
|
if ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if frame == nil {
|
||||||
|
return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
|
||||||
|
}
|
||||||
|
// We don't need a check for the stream ID here.
|
||||||
|
// The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
|
||||||
|
if frame.Offset != 0 {
|
||||||
|
return nil, errors.New("received stream data with non-zero offset")
|
||||||
|
}
|
||||||
|
if utils.Debug() {
|
||||||
|
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID)
|
||||||
|
hdr.Log()
|
||||||
|
wire.LogFrame(frame, false)
|
||||||
|
}
|
||||||
|
return frame, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// packUnencryptedPacket provides a low-overhead way to pack a packet.
|
||||||
|
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
|
||||||
|
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) {
|
||||||
|
raw := getPacketBuffer()
|
||||||
|
buffer := bytes.NewBuffer(raw)
|
||||||
|
if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payloadStartIndex := buffer.Len()
|
||||||
|
if err := sf.Write(buffer, hdr.Version); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
raw = raw[0:buffer.Len()]
|
||||||
|
_ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
|
||||||
|
raw = raw[0 : buffer.Len()+aead.Overhead()]
|
||||||
|
if utils.Debug() {
|
||||||
|
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
|
||||||
|
hdr.Log()
|
||||||
|
wire.LogFrame(sf, true)
|
||||||
|
}
|
||||||
|
return raw, nil
|
||||||
|
}
|
111
mint_utils_test.go
Normal file
111
mint_utils_test.go
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Packing and unpacking Initial packets", func() {
|
||||||
|
var aead crypto.AEAD
|
||||||
|
connID := protocol.ConnectionID(0x1337)
|
||||||
|
ver := protocol.VersionTLS
|
||||||
|
hdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeRetry,
|
||||||
|
PacketNumber: 0x42,
|
||||||
|
ConnectionID: connID,
|
||||||
|
Version: ver,
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
aead, err = crypto.NewNullAEAD(protocol.PerspectiveServer, connID, ver)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
// set hdr.Raw
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err = hdr.Write(buf, protocol.PerspectiveServer, ver)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hdr.Raw = buf.Bytes()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("unpacking", func() {
|
||||||
|
packPacket := func(frames []wire.Frame) []byte {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err := hdr.Write(buf, protocol.PerspectiveClient, ver)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
payloadStartIndex := buf.Len()
|
||||||
|
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
|
||||||
|
for _, f := range frames {
|
||||||
|
err := f.Write(buf, ver)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
raw := buf.Bytes()
|
||||||
|
return aeadCl.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
|
||||||
|
}
|
||||||
|
|
||||||
|
It("unpacks a packet", func() {
|
||||||
|
f := &wire.StreamFrame{
|
||||||
|
StreamID: 0,
|
||||||
|
Data: []byte("foobar"),
|
||||||
|
}
|
||||||
|
p := packPacket([]wire.Frame{f})
|
||||||
|
frame, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(frame).To(Equal(f))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects a packet that doesn't contain a STREAM_FRAME", func() {
|
||||||
|
p := packPacket([]wire.Frame{&wire.PingFrame{}})
|
||||||
|
_, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||||
|
Expect(err).To(MatchError("Packet doesn't contain a STREAM_FRAME"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects a packet that has a STREAM_FRAME for the wrong stream", func() {
|
||||||
|
f := &wire.StreamFrame{
|
||||||
|
StreamID: 42,
|
||||||
|
Data: []byte("foobar"),
|
||||||
|
}
|
||||||
|
p := packPacket([]wire.Frame{f})
|
||||||
|
_, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||||
|
Expect(err).To(MatchError("UnencryptedStreamData: received unencrypted stream data on stream 42"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("rejects a packet that has a STREAM_FRAME with a non-zero offset", func() {
|
||||||
|
f := &wire.StreamFrame{
|
||||||
|
StreamID: 0,
|
||||||
|
Offset: 10,
|
||||||
|
Data: []byte("foobar"),
|
||||||
|
}
|
||||||
|
p := packPacket([]wire.Frame{f})
|
||||||
|
_, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||||
|
Expect(err).To(MatchError("received stream data with non-zero offset"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("packing", func() {
|
||||||
|
var unpacker *packetUnpacker
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
unpacker = &packetUnpacker{aead: &nullAEAD{aeadCl}, version: ver}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("packs a packet", func() {
|
||||||
|
f := &wire.StreamFrame{
|
||||||
|
Data: []byte("foobar"),
|
||||||
|
FinBit: true,
|
||||||
|
}
|
||||||
|
data, err := packUnencryptedPacket(aead, hdr, f, protocol.PerspectiveServer)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
packet, err := unpacker.Unpack(hdr.Raw, hdr, data[len(hdr.Raw):])
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(packet.frames).To(Equal([]wire.Frame{f}))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
|
@ -17,9 +17,9 @@ type packetNumberGenerator struct {
|
||||||
nextToSkip protocol.PacketNumber
|
nextToSkip protocol.PacketNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
||||||
return &packetNumberGenerator{
|
return &packetNumberGenerator{
|
||||||
next: 1,
|
next: initial,
|
||||||
averagePeriod: averagePeriod,
|
averagePeriod: averagePeriod,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,12 @@ var _ = Describe("Packet Number Generator", func() {
|
||||||
var png packetNumberGenerator
|
var png packetNumberGenerator
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
png = *newPacketNumberGenerator(100)
|
png = *newPacketNumberGenerator(1, 100)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("can be initialized to return any first packet number", func() {
|
||||||
|
png = *newPacketNumberGenerator(12345, 100)
|
||||||
|
Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345)))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("gets 1 as the first packet number", func() {
|
It("gets 1 as the first packet number", func() {
|
||||||
|
|
|
@ -32,9 +32,11 @@ type packetPacker struct {
|
||||||
ackFrame *wire.AckFrame
|
ackFrame *wire.AckFrame
|
||||||
leastUnacked protocol.PacketNumber
|
leastUnacked protocol.PacketNumber
|
||||||
omitConnectionID bool
|
omitConnectionID bool
|
||||||
|
hasSentPacket bool // has the packetPacker already sent a packet
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPacketPacker(connectionID protocol.ConnectionID,
|
func newPacketPacker(connectionID protocol.ConnectionID,
|
||||||
|
initialPacketNumber protocol.PacketNumber,
|
||||||
cryptoSetup handshake.CryptoSetup,
|
cryptoSetup handshake.CryptoSetup,
|
||||||
streamFramer *streamFramer,
|
streamFramer *streamFramer,
|
||||||
perspective protocol.Perspective,
|
perspective protocol.Perspective,
|
||||||
|
@ -46,7 +48,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
|
||||||
perspective: perspective,
|
perspective: perspective,
|
||||||
version: version,
|
version: version,
|
||||||
streamFramer: streamFramer,
|
streamFramer: streamFramer,
|
||||||
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
|
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +118,12 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
|
||||||
// PackPacket packs a new packet
|
// PackPacket packs a new packet
|
||||||
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
|
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
|
||||||
func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
||||||
if p.streamFramer.HasCryptoStreamFrame() {
|
hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame()
|
||||||
|
// if this is the first packet to be send, make sure it contains stream data
|
||||||
|
if !p.hasSentPacket && !hasCryptoStreamFrame {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if hasCryptoStreamFrame {
|
||||||
return p.packCryptoPacket()
|
return p.packCryptoPacket()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -266,18 +273,21 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
||||||
pnum := p.packetNumberGenerator.Peek()
|
pnum := p.packetNumberGenerator.Peek()
|
||||||
packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked)
|
packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked)
|
||||||
|
|
||||||
var isLongHeader bool
|
|
||||||
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
|
|
||||||
// TODO: set the Long Header type
|
|
||||||
packetNumberLen = protocol.PacketNumberLen4
|
|
||||||
isLongHeader = true
|
|
||||||
}
|
|
||||||
|
|
||||||
header := &wire.Header{
|
header := &wire.Header{
|
||||||
ConnectionID: p.connectionID,
|
ConnectionID: p.connectionID,
|
||||||
PacketNumber: pnum,
|
PacketNumber: pnum,
|
||||||
PacketNumberLen: packetNumberLen,
|
PacketNumberLen: packetNumberLen,
|
||||||
IsLongHeader: isLongHeader,
|
}
|
||||||
|
|
||||||
|
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
|
||||||
|
header.PacketNumberLen = protocol.PacketNumberLen4
|
||||||
|
header.IsLongHeader = true
|
||||||
|
if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
|
||||||
|
header.Type = protocol.PacketTypeInitial
|
||||||
|
// TODO(#886): add padding
|
||||||
|
} else {
|
||||||
|
header.Type = protocol.PacketTypeHandshake
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
|
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
|
||||||
|
@ -292,7 +302,6 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
||||||
header.Version = p.version
|
header.Version = p.version
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
header.Type = p.cryptoSetup.GetNextPacketType()
|
|
||||||
if encLevel != protocol.EncryptionForwardSecure {
|
if encLevel != protocol.EncryptionForwardSecure {
|
||||||
header.Version = p.version
|
header.Version = p.version
|
||||||
}
|
}
|
||||||
|
@ -330,7 +339,7 @@ func (p *packetPacker) writeAndSealPacket(
|
||||||
if num != header.PacketNumber {
|
if num != header.PacketNumber {
|
||||||
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
|
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
|
||||||
}
|
}
|
||||||
|
p.hasSentPacket = true
|
||||||
return raw, nil
|
return raw, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,6 @@ type mockCryptoSetup struct {
|
||||||
divNonce []byte
|
divNonce []byte
|
||||||
encLevelSeal protocol.EncryptionLevel
|
encLevelSeal protocol.EncryptionLevel
|
||||||
encLevelSealCrypto protocol.EncryptionLevel
|
encLevelSealCrypto protocol.EncryptionLevel
|
||||||
nextPacketType protocol.PacketType
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ handshake.CryptoSetup = &mockCryptoSetup{}
|
var _ handshake.CryptoSetup = &mockCryptoSetup{}
|
||||||
|
@ -50,7 +49,6 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
|
||||||
}
|
}
|
||||||
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
|
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
|
||||||
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
|
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
|
||||||
func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType }
|
|
||||||
|
|
||||||
var _ = Describe("Packet packer", func() {
|
var _ = Describe("Packet packer", func() {
|
||||||
var (
|
var (
|
||||||
|
@ -69,13 +67,14 @@ var _ = Describe("Packet packer", func() {
|
||||||
packer = &packetPacker{
|
packer = &packetPacker{
|
||||||
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
|
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
|
||||||
connectionID: 0x1337,
|
connectionID: 0x1337,
|
||||||
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
|
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
|
||||||
streamFramer: streamFramer,
|
streamFramer: streamFramer,
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
}
|
}
|
||||||
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
|
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
|
||||||
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
|
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
|
||||||
packer.version = protocol.VersionWhatever
|
packer.version = protocol.VersionWhatever
|
||||||
|
packer.hasSentPacket = true
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns nil when no packet is queued", func() {
|
It("returns nil when no packet is queued", func() {
|
||||||
|
@ -191,13 +190,6 @@ var _ = Describe("Packet packer", func() {
|
||||||
Expect(h.Version).To(Equal(versionIETFHeader))
|
Expect(h.Version).To(Equal(versionIETFHeader))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sets the packet type based on the state of the handshake", func() {
|
|
||||||
packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5
|
|
||||||
h := packer.getHeader(protocol.EncryptionSecure)
|
|
||||||
Expect(h.IsLongHeader).To(BeTrue())
|
|
||||||
Expect(h.Type).To(Equal(protocol.PacketType(5)))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("uses the Short Header format for forward-secure packets", func() {
|
It("uses the Short Header format for forward-secure packets", func() {
|
||||||
h := packer.getHeader(protocol.EncryptionForwardSecure)
|
h := packer.getHeader(protocol.EncryptionForwardSecure)
|
||||||
Expect(h.IsLongHeader).To(BeFalse())
|
Expect(h.IsLongHeader).To(BeFalse())
|
||||||
|
@ -269,7 +261,7 @@ var _ = Describe("Packet packer", func() {
|
||||||
Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber))
|
Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("packs a StopWaitingFrame first", func() {
|
It("packs a STOP_WAITING frame first", func() {
|
||||||
packer.packetNumberGenerator.next = 15
|
packer.packetNumberGenerator.next = 15
|
||||||
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
||||||
packer.QueueControlFrame(&wire.RstStreamFrame{})
|
packer.QueueControlFrame(&wire.RstStreamFrame{})
|
||||||
|
@ -281,7 +273,7 @@ var _ = Describe("Packet packer", func() {
|
||||||
Expect(p.frames[0]).To(Equal(swf))
|
Expect(p.frames[0]).To(Equal(swf))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sets the LeastUnackedDelta length of a StopWaitingFrame", func() {
|
It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
|
||||||
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
|
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
|
||||||
packer.packetNumberGenerator.next = packetNumber
|
packer.packetNumberGenerator.next = packetNumber
|
||||||
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
|
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
|
||||||
|
@ -292,7 +284,7 @@ var _ = Describe("Packet packer", func() {
|
||||||
Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("does not pack a packet containing only a StopWaitingFrame", func() {
|
It("does not pack a packet containing only a STOP_WAITING frame", func() {
|
||||||
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
||||||
packer.QueueControlFrame(swf)
|
packer.QueueControlFrame(swf)
|
||||||
p, err := packer.PackPacket()
|
p, err := packer.PackPacket()
|
||||||
|
@ -307,6 +299,14 @@ var _ = Describe("Packet packer", func() {
|
||||||
Expect(p).ToNot(BeNil())
|
Expect(p).ToNot(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
|
||||||
|
packer.hasSentPacket = false
|
||||||
|
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
|
||||||
|
p, err := packer.PackPacket()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(p).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
It("packs many control frames into 1 packets", func() {
|
It("packs many control frames into 1 packets", func() {
|
||||||
f := &wire.AckFrame{LargestAcked: 1}
|
f := &wire.AckFrame{LargestAcked: 1}
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
|
@ -602,7 +602,7 @@ var _ = Describe("Packet packer", func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("Blocked frames", func() {
|
Context("BLOCKED frames", func() {
|
||||||
It("queues a BLOCKED frame", func() {
|
It("queues a BLOCKED frame", func() {
|
||||||
length := 100
|
length := 100
|
||||||
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
|
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
|
||||||
|
@ -750,7 +750,7 @@ var _ = Describe("Packet packer", func() {
|
||||||
Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment"))
|
Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("refuses to retransmit packets without a StopWaitingFrame", func() {
|
It("refuses to retransmit packets without a STOP_WAITING Frame", func() {
|
||||||
packer.stopWaiting = nil
|
packer.stopWaiting = nil
|
||||||
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
|
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
|
||||||
EncryptionLevel: protocol.EncryptionSecure,
|
EncryptionLevel: protocol.EncryptionSecure,
|
||||||
|
|
106
server.go
106
server.go
|
@ -19,6 +19,7 @@ import (
|
||||||
// packetHandler handles packets
|
// packetHandler handles packets
|
||||||
type packetHandler interface {
|
type packetHandler interface {
|
||||||
Session
|
Session
|
||||||
|
getCryptoStream() cryptoStream
|
||||||
handshakeStatus() <-chan handshakeEvent
|
handshakeStatus() <-chan handshakeEvent
|
||||||
handlePacket(*receivedPacket)
|
handlePacket(*receivedPacket)
|
||||||
GetVersion() protocol.VersionNumber
|
GetVersion() protocol.VersionNumber
|
||||||
|
@ -33,6 +34,9 @@ type server struct {
|
||||||
|
|
||||||
conn net.PacketConn
|
conn net.PacketConn
|
||||||
|
|
||||||
|
supportsTLS bool
|
||||||
|
serverTLS *serverTLS
|
||||||
|
|
||||||
certChain crypto.CertChain
|
certChain crypto.CertChain
|
||||||
scfg *handshake.ServerConfig
|
scfg *handshake.ServerConfig
|
||||||
|
|
||||||
|
@ -77,11 +81,21 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
config = populateServerConfig(config)
|
||||||
|
|
||||||
|
// check if any of the supported versions supports TLS
|
||||||
|
var supportsTLS bool
|
||||||
|
for _, v := range config.Versions {
|
||||||
|
if v.UsesTLS() {
|
||||||
|
supportsTLS = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
tlsConf: tlsConf,
|
tlsConf: tlsConf,
|
||||||
config: populateServerConfig(config),
|
config: config,
|
||||||
certChain: certChain,
|
certChain: certChain,
|
||||||
scfg: scfg,
|
scfg: scfg,
|
||||||
sessions: map[protocol.ConnectionID]packetHandler{},
|
sessions: map[protocol.ConnectionID]packetHandler{},
|
||||||
|
@ -89,12 +103,47 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
||||||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||||
sessionQueue: make(chan Session, 5),
|
sessionQueue: make(chan Session, 5),
|
||||||
errorChan: make(chan struct{}),
|
errorChan: make(chan struct{}),
|
||||||
|
supportsTLS: supportsTLS,
|
||||||
|
}
|
||||||
|
if supportsTLS {
|
||||||
|
if err := s.setupTLS(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
go s.serve()
|
go s.serve()
|
||||||
utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) setupTLS() error {
|
||||||
|
cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.serverTLS = serverTLS
|
||||||
|
// handle TLS connection establishment statelessly
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.errorChan:
|
||||||
|
return
|
||||||
|
case sess := <-sessionChan:
|
||||||
|
// TODO: think about what to do with connection ID collisions
|
||||||
|
connID := sess.(*session).connectionID
|
||||||
|
s.sessionsMutex.Lock()
|
||||||
|
s.sessions[connID] = sess
|
||||||
|
s.sessionsMutex.Unlock()
|
||||||
|
s.runHandshakeAndSession(sess, connID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
|
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
|
||||||
if cookie == nil {
|
if cookie == nil {
|
||||||
return false
|
return false
|
||||||
|
@ -225,8 +274,16 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
||||||
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
|
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
|
||||||
}
|
}
|
||||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||||
|
packetData := packet[len(packet)-r.Len():]
|
||||||
connID := hdr.ConnectionID
|
connID := hdr.ConnectionID
|
||||||
|
|
||||||
|
if hdr.Type == protocol.PacketTypeInitial {
|
||||||
|
if s.supportsTLS {
|
||||||
|
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
s.sessionsMutex.RLock()
|
s.sessionsMutex.RLock()
|
||||||
session, sessionKnown := s.sessions[connID]
|
session, sessionKnown := s.sessions[connID]
|
||||||
s.sessionsMutex.RUnlock()
|
s.sessionsMutex.RUnlock()
|
||||||
|
@ -279,11 +336,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header
|
|
||||||
if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
|
||||||
_, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.config.Versions), remoteAddr)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !sessionKnown {
|
if !sessionKnown {
|
||||||
version := hdr.Version
|
version := hdr.Version
|
||||||
|
@ -307,34 +359,38 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
||||||
s.sessions[connID] = session
|
s.sessions[connID] = session
|
||||||
s.sessionsMutex.Unlock()
|
s.sessionsMutex.Unlock()
|
||||||
|
|
||||||
go func() {
|
s.runHandshakeAndSession(session, connID)
|
||||||
// session.run() returns as soon as the session is closed
|
|
||||||
_ = session.run()
|
|
||||||
s.removeConnection(connID)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
ev := <-session.handshakeStatus()
|
|
||||||
if ev.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if ev.encLevel == protocol.EncryptionForwardSecure {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.sessionQueue <- session
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
session.handlePacket(&receivedPacket{
|
session.handlePacket(&receivedPacket{
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
header: hdr,
|
header: hdr,
|
||||||
data: packet[len(packet)-r.Len():],
|
data: packetData,
|
||||||
rcvTime: rcvTime,
|
rcvTime: rcvTime,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
|
||||||
|
go func() {
|
||||||
|
_ = session.run()
|
||||||
|
// session.run() returns as soon as the session is closed
|
||||||
|
s.removeConnection(connID)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
ev := <-session.handshakeStatus()
|
||||||
|
if ev.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ev.encLevel == protocol.EncryptionForwardSecure {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.sessionQueue <- session
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *server) removeConnection(id protocol.ConnectionID) {
|
func (s *server) removeConnection(id protocol.ConnectionID) {
|
||||||
s.sessionsMutex.Lock()
|
s.sessionsMutex.Lock()
|
||||||
s.sessions[id] = nil
|
s.sessions[id] = nil
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
"github.com/lucas-clemente/quic-go/qerr"
|
"github.com/lucas-clemente/quic-go/qerr"
|
||||||
|
@ -67,6 +68,7 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not imple
|
||||||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||||
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
||||||
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
|
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
|
||||||
|
func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") }
|
||||||
|
|
||||||
var _ Session = &mockSession{}
|
var _ Session = &mockSession{}
|
||||||
var _ NonFWSession = &mockSession{}
|
var _ NonFWSession = &mockSession{}
|
||||||
|
@ -96,7 +98,8 @@ var _ = Describe("Server", func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
conn = &mockPacketConn{addr: &net.UDPAddr{}}
|
conn = newMockPacketConn()
|
||||||
|
conn.addr = &net.UDPAddr{}
|
||||||
config = &Config{Versions: protocol.SupportedVersions}
|
config = &Config{Versions: protocol.SupportedVersions}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -235,14 +238,14 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("works if no quic.Config is given", func(done Done) {
|
It("works if no quic.Config is given", func(done Done) {
|
||||||
ln, err := ListenAddr("127.0.0.1:0", nil, config)
|
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), nil)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(ln.Close()).To(Succeed())
|
Expect(ln.Close()).To(Succeed())
|
||||||
close(done)
|
close(done)
|
||||||
}, 1)
|
}, 1)
|
||||||
|
|
||||||
It("closes properly", func() {
|
It("closes properly", func() {
|
||||||
ln, err := ListenAddr("127.0.0.1:0", nil, config)
|
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
var returned bool
|
var returned bool
|
||||||
|
@ -409,7 +412,7 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
||||||
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
||||||
conn.dataToRead = b.Bytes()
|
conn.dataToRead <- b.Bytes()
|
||||||
conn.dataReadFrom = udpAddr
|
conn.dataReadFrom = udpAddr
|
||||||
ln, err := Listen(conn, nil, config)
|
ln, err := Listen(conn, nil, config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
@ -432,7 +435,7 @@ var _ = Describe("Server", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
|
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
|
||||||
config.Versions = []protocol.VersionNumber{99}
|
config.Versions = []protocol.VersionNumber{99, protocol.VersionTLS}
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
hdr := wire.Header{
|
hdr := wire.Header{
|
||||||
Type: protocol.PacketTypeInitial,
|
Type: protocol.PacketTypeInitial,
|
||||||
|
@ -441,11 +444,12 @@ var _ = Describe("Server", func() {
|
||||||
PacketNumber: 0x55,
|
PacketNumber: 0x55,
|
||||||
Version: 0x1234,
|
Version: 0x1234,
|
||||||
}
|
}
|
||||||
hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
||||||
conn.dataToRead = b.Bytes()
|
conn.dataToRead <- b.Bytes()
|
||||||
conn.dataReadFrom = udpAddr
|
conn.dataReadFrom = udpAddr
|
||||||
ln, err := Listen(conn, nil, config)
|
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
@ -466,9 +470,32 @@ var _ = Describe("Server", func() {
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
|
||||||
|
version := protocol.VersionNumber(99)
|
||||||
|
Expect(version.UsesTLS()).To(BeFalse())
|
||||||
|
config.Versions = []protocol.VersionNumber{version}
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
hdr := wire.Header{
|
||||||
|
Type: protocol.PacketTypeInitial,
|
||||||
|
IsLongHeader: true,
|
||||||
|
ConnectionID: 0x1337,
|
||||||
|
PacketNumber: 0x55,
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
||||||
|
conn.dataToRead <- b.Bytes()
|
||||||
|
conn.dataReadFrom = udpAddr
|
||||||
|
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||||
|
defer ln.Close()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
|
||||||
|
})
|
||||||
|
|
||||||
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
|
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
|
||||||
conn.dataReadFrom = udpAddr
|
conn.dataReadFrom = udpAddr
|
||||||
conn.dataToRead = []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
|
conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
|
||||||
ln, err := Listen(conn, nil, config)
|
ln, err := Listen(conn, nil, config)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go func() {
|
go func() {
|
||||||
|
|
179
server_tls.go
Normal file
179
server_tls.go
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
)
|
||||||
|
|
||||||
|
type nullAEAD struct {
|
||||||
|
aead crypto.AEAD
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ quicAEAD = &nullAEAD{}
|
||||||
|
|
||||||
|
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||||
|
data, err := n.aead.Open(dst, src, packetNumber, associatedData)
|
||||||
|
return data, protocol.EncryptionUnencrypted, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverTLS struct {
|
||||||
|
conn net.PacketConn
|
||||||
|
config *Config
|
||||||
|
supportedVersions []protocol.VersionNumber
|
||||||
|
mintConf *mint.Config
|
||||||
|
cookieProtector mint.CookieProtector
|
||||||
|
params *handshake.TransportParameters
|
||||||
|
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
|
||||||
|
|
||||||
|
sessionChan chan<- packetHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerTLS(
|
||||||
|
conn net.PacketConn,
|
||||||
|
config *Config,
|
||||||
|
cookieHandler *handshake.CookieHandler,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
) (*serverTLS, <-chan packetHandler, error) {
|
||||||
|
mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
mconf.RequireCookie = true
|
||||||
|
cs, err := mint.NewDefaultCookieProtector()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
mconf.CookieProtector = cs
|
||||||
|
mconf.CookieHandler = cookieHandler
|
||||||
|
|
||||||
|
sessionChan := make(chan packetHandler)
|
||||||
|
s := &serverTLS{
|
||||||
|
conn: conn,
|
||||||
|
config: config,
|
||||||
|
supportedVersions: config.Versions,
|
||||||
|
mintConf: mconf,
|
||||||
|
sessionChan: sessionChan,
|
||||||
|
params: &handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
MaxStreams: protocol.MaxIncomingStreams,
|
||||||
|
IdleTimeout: config.IdleTimeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.newMintConn = s.newMintConnImpl
|
||||||
|
return s, sessionChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
|
||||||
|
utils.Debugf("Received a Packet. Handling it statelessly.")
|
||||||
|
sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
|
||||||
|
if err != nil {
|
||||||
|
utils.Errorf("Error occured handling initial packet: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if sess == nil { // a stateless reset was done
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.sessionChan <- sess
|
||||||
|
}
|
||||||
|
|
||||||
|
// will be set to s.newMintConn by the constructor
|
||||||
|
func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
|
||||||
|
conn := mint.Server(bc, s.mintConf)
|
||||||
|
extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v)
|
||||||
|
if err := conn.SetExtensionHandler(extHandler); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
tls := newMintController(bc, s.mintConf, protocol.PerspectiveServer)
|
||||||
|
tls.SetExtensionHandler(extHandler)
|
||||||
|
return tls, extHandler.GetPeerParams(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
|
||||||
|
// TODO: check length requirement
|
||||||
|
// check version, if not matching send VNP
|
||||||
|
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
|
||||||
|
utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
|
||||||
|
_, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.supportedVersions), remoteAddr)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// unpack packet and check stream frame contents
|
||||||
|
version := hdr.Version
|
||||||
|
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
frame, err := unpackInitialPacket(aead, hdr, data, version)
|
||||||
|
if err != nil {
|
||||||
|
utils.Debugf("Error unpacking initial packet: %s", err)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
bc := handshake.NewCryptoStreamConn(remoteAddr)
|
||||||
|
bc.AddDataForReading(frame.Data)
|
||||||
|
tls, paramsChan, err := s.newMintConn(bc, hdr.Version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
alert := tls.Handshake()
|
||||||
|
if alert == mint.AlertStatelessRetry {
|
||||||
|
// the HelloRetryRequest was written to the bufferConn
|
||||||
|
// Take that data and write send a Retry packet
|
||||||
|
replyHdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
Type: protocol.PacketTypeRetry,
|
||||||
|
ConnectionID: hdr.ConnectionID, // echo the client's connection ID
|
||||||
|
PacketNumber: hdr.PacketNumber, // echo the client's packet number
|
||||||
|
Version: version,
|
||||||
|
}
|
||||||
|
f := &wire.StreamFrame{
|
||||||
|
StreamID: version.CryptoStreamID(),
|
||||||
|
Data: bc.GetDataForWriting(),
|
||||||
|
}
|
||||||
|
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = s.conn.WriteTo(data, remoteAddr)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if alert != mint.AlertNoAlert {
|
||||||
|
return nil, alert
|
||||||
|
}
|
||||||
|
if tls.State() != mint.StateServerNegotiated {
|
||||||
|
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
|
||||||
|
}
|
||||||
|
if alert := tls.Handshake(); alert != mint.AlertNoAlert {
|
||||||
|
return nil, alert
|
||||||
|
}
|
||||||
|
if tls.State() != mint.StateServerWaitFlight2 {
|
||||||
|
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
|
||||||
|
}
|
||||||
|
params := <-paramsChan
|
||||||
|
sess, err := newTLSServerSession(
|
||||||
|
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||||
|
hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here
|
||||||
|
protocol.PacketNumber(1), // TODO: use a random packet number here
|
||||||
|
s.config,
|
||||||
|
tls,
|
||||||
|
bc,
|
||||||
|
aead,
|
||||||
|
¶ms,
|
||||||
|
version,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cs := sess.getCryptoStream()
|
||||||
|
cs.SetReadOffset(frame.DataLen())
|
||||||
|
bc.SetStream(cs)
|
||||||
|
return sess, nil
|
||||||
|
}
|
116
server_tls_test.go
Normal file
116
server_tls_test.go
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
||||||
|
|
||||||
|
"github.com/bifurcation/mint"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Stateless TLS handling", func() {
|
||||||
|
var (
|
||||||
|
conn *mockPacketConn
|
||||||
|
server *serverTLS
|
||||||
|
sessionChan <-chan packetHandler
|
||||||
|
mintTLS *mockhandshake.MockMintTLS
|
||||||
|
extHandler *mocks.MockTLSExtensionHandler
|
||||||
|
mintReply io.Writer
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
mintTLS = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||||
|
extHandler = mocks.NewMockTLSExtensionHandler(mockCtrl)
|
||||||
|
conn = newMockPacketConn()
|
||||||
|
config := &Config{
|
||||||
|
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
|
||||||
|
mintReply = bc
|
||||||
|
return mintTLS, extHandler.GetPeerParams(), nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
getPacket := func(f wire.Frame) (*wire.Header, []byte) {
|
||||||
|
hdrBuf := &bytes.Buffer{}
|
||||||
|
hdr := &wire.Header{
|
||||||
|
IsLongHeader: true,
|
||||||
|
PacketNumber: 1,
|
||||||
|
Version: protocol.VersionTLS,
|
||||||
|
}
|
||||||
|
err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
hdr.Raw = hdrBuf.Bytes()
|
||||||
|
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, 0, protocol.VersionTLS)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err = f.Write(buf, protocol.VersionTLS)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
return hdr, aead.Seal(nil, buf.Bytes(), 1, hdr.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
It("sends a version negotiation packet if it doesn't support the version", func() {
|
||||||
|
server.HandleInitial(nil, &wire.Header{Version: 0x1337}, nil)
|
||||||
|
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
||||||
|
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionUnknown)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(hdr.IsVersionNegotiation).To(BeTrue())
|
||||||
|
Expect(sessionChan).ToNot(Receive())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores packets with invalid contents", func() {
|
||||||
|
hdr, data := getPacket(&wire.StreamFrame{StreamID: 10, Offset: 11, Data: []byte("foobar")})
|
||||||
|
server.HandleInitial(nil, hdr, data)
|
||||||
|
Expect(conn.dataWritten.Len()).To(BeZero())
|
||||||
|
Expect(sessionChan).ToNot(Receive())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("replies with a Retry packet, if a Cookie is required", func() {
|
||||||
|
extHandler.EXPECT().GetPeerParams()
|
||||||
|
mintTLS.EXPECT().Handshake().Return(mint.AlertStatelessRetry).Do(func() {
|
||||||
|
mintReply.Write([]byte("Retry with this Cookie"))
|
||||||
|
})
|
||||||
|
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
|
||||||
|
server.HandleInitial(nil, hdr, data)
|
||||||
|
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
||||||
|
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionTLS)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||||
|
Expect(sessionChan).ToNot(Receive())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("replies with a Handshake packet and creates a session, if no Cookie is required", func() {
|
||||||
|
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert).Do(func() {
|
||||||
|
mintReply.Write([]byte("Server Hello"))
|
||||||
|
})
|
||||||
|
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||||
|
mintTLS.EXPECT().State().Return(mint.StateServerNegotiated)
|
||||||
|
mintTLS.EXPECT().State().Return(mint.StateServerWaitFlight2)
|
||||||
|
paramsChan := make(chan handshake.TransportParameters, 1)
|
||||||
|
paramsChan <- handshake.TransportParameters{}
|
||||||
|
extHandler.EXPECT().GetPeerParams().Return(paramsChan)
|
||||||
|
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
server.HandleInitial(nil, hdr, data)
|
||||||
|
// the Handshake packet is written by the session
|
||||||
|
Expect(conn.dataWritten.Len()).To(BeZero())
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
Eventually(sessionChan).Should(Receive())
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
})
|
||||||
|
})
|
248
session.go
248
session.go
|
@ -11,6 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||||
"github.com/lucas-clemente/quic-go/congestion"
|
"github.com/lucas-clemente/quic-go/congestion"
|
||||||
|
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||||
|
@ -60,7 +61,7 @@ type session struct {
|
||||||
conn connection
|
conn connection
|
||||||
|
|
||||||
streamsMap *streamsMap
|
streamsMap *streamsMap
|
||||||
cryptoStream streamI
|
cryptoStream cryptoStream
|
||||||
|
|
||||||
rttStats *congestion.RTTStats
|
rttStats *congestion.RTTStats
|
||||||
|
|
||||||
|
@ -126,21 +127,48 @@ func newSession(
|
||||||
conn connection,
|
conn connection,
|
||||||
v protocol.VersionNumber,
|
v protocol.VersionNumber,
|
||||||
connectionID protocol.ConnectionID,
|
connectionID protocol.ConnectionID,
|
||||||
sCfg *handshake.ServerConfig,
|
scfg *handshake.ServerConfig,
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (packetHandler, error) {
|
) (packetHandler, error) {
|
||||||
|
paramsChan := make(chan handshake.TransportParameters)
|
||||||
|
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||||
s := &session{
|
s := &session{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
version: v,
|
version: v,
|
||||||
config: config,
|
config: config,
|
||||||
|
aeadChanged: aeadChanged,
|
||||||
|
paramsChan: paramsChan,
|
||||||
}
|
}
|
||||||
return s, s.setup(sCfg, "", tlsConf, v, nil)
|
s.preSetup()
|
||||||
|
transportParams := &handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
MaxStreams: protocol.MaxIncomingStreams,
|
||||||
|
IdleTimeout: s.config.IdleTimeout,
|
||||||
|
}
|
||||||
|
cs, err := newCryptoSetup(
|
||||||
|
s.cryptoStream,
|
||||||
|
s.connectionID,
|
||||||
|
s.conn.RemoteAddr(),
|
||||||
|
s.version,
|
||||||
|
scfg,
|
||||||
|
transportParams,
|
||||||
|
s.config.Versions,
|
||||||
|
s.config.AcceptCookie,
|
||||||
|
paramsChan,
|
||||||
|
aeadChanged,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.cryptoSetup = cs
|
||||||
|
return s, s.postSetup(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// declare this as a variable, such that we can it mock it in the tests
|
// declare this as a variable, so that we can it mock it in the tests
|
||||||
var newClientSession = func(
|
var newClientSession = func(
|
||||||
conn connection,
|
conn connection,
|
||||||
hostname string,
|
hostname string,
|
||||||
|
@ -151,27 +179,130 @@ var newClientSession = func(
|
||||||
initialVersion protocol.VersionNumber,
|
initialVersion protocol.VersionNumber,
|
||||||
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
|
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
|
||||||
) (packetHandler, error) {
|
) (packetHandler, error) {
|
||||||
|
paramsChan := make(chan handshake.TransportParameters)
|
||||||
|
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||||
s := &session{
|
s := &session{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
connectionID: connectionID,
|
connectionID: connectionID,
|
||||||
perspective: protocol.PerspectiveClient,
|
perspective: protocol.PerspectiveClient,
|
||||||
version: v,
|
version: v,
|
||||||
config: config,
|
config: config,
|
||||||
|
aeadChanged: aeadChanged,
|
||||||
|
paramsChan: paramsChan,
|
||||||
}
|
}
|
||||||
return s, s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions)
|
s.preSetup()
|
||||||
|
transportParams := &handshake.TransportParameters{
|
||||||
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
MaxStreams: protocol.MaxIncomingStreams,
|
||||||
|
IdleTimeout: s.config.IdleTimeout,
|
||||||
|
OmitConnectionID: s.config.RequestConnectionIDOmission,
|
||||||
|
}
|
||||||
|
cs, err := newCryptoSetupClient(
|
||||||
|
s.cryptoStream,
|
||||||
|
hostname,
|
||||||
|
s.connectionID,
|
||||||
|
s.version,
|
||||||
|
tlsConf,
|
||||||
|
transportParams,
|
||||||
|
paramsChan,
|
||||||
|
aeadChanged,
|
||||||
|
initialVersion,
|
||||||
|
negotiatedVersions,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.cryptoSetup = cs
|
||||||
|
return s, s.postSetup(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) setup(
|
func newTLSServerSession(
|
||||||
scfg *handshake.ServerConfig,
|
conn connection,
|
||||||
hostname string,
|
connectionID protocol.ConnectionID,
|
||||||
tlsConf *tls.Config,
|
initialPacketNumber protocol.PacketNumber,
|
||||||
initialVersion protocol.VersionNumber,
|
config *Config,
|
||||||
negotiatedVersions []protocol.VersionNumber,
|
tls handshake.MintTLS,
|
||||||
) error {
|
cryptoStreamConn *handshake.CryptoStreamConn,
|
||||||
|
nullAEAD crypto.AEAD,
|
||||||
|
peerParams *handshake.TransportParameters,
|
||||||
|
v protocol.VersionNumber,
|
||||||
|
) (packetHandler, error) {
|
||||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||||
paramsChan := make(chan handshake.TransportParameters)
|
s := &session{
|
||||||
s.aeadChanged = aeadChanged
|
conn: conn,
|
||||||
s.paramsChan = paramsChan
|
config: config,
|
||||||
|
connectionID: connectionID,
|
||||||
|
perspective: protocol.PerspectiveServer,
|
||||||
|
version: v,
|
||||||
|
aeadChanged: aeadChanged,
|
||||||
|
}
|
||||||
|
s.preSetup()
|
||||||
|
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
|
||||||
|
tls,
|
||||||
|
cryptoStreamConn,
|
||||||
|
nullAEAD,
|
||||||
|
aeadChanged,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
if err := s.postSetup(initialPacketNumber); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.peerParams = peerParams
|
||||||
|
s.processTransportParameters(peerParams)
|
||||||
|
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// declare this as a variable, such that we can it mock it in the tests
|
||||||
|
var newTLSClientSession = func(
|
||||||
|
conn connection,
|
||||||
|
hostname string,
|
||||||
|
v protocol.VersionNumber,
|
||||||
|
connectionID protocol.ConnectionID,
|
||||||
|
config *Config,
|
||||||
|
tls handshake.MintTLS,
|
||||||
|
paramsChan <-chan handshake.TransportParameters,
|
||||||
|
initialPacketNumber protocol.PacketNumber,
|
||||||
|
) (packetHandler, error) {
|
||||||
|
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||||
|
s := &session{
|
||||||
|
conn: conn,
|
||||||
|
config: config,
|
||||||
|
connectionID: connectionID,
|
||||||
|
perspective: protocol.PerspectiveClient,
|
||||||
|
version: v,
|
||||||
|
aeadChanged: aeadChanged,
|
||||||
|
paramsChan: paramsChan,
|
||||||
|
}
|
||||||
|
s.preSetup()
|
||||||
|
tls.SetCryptoStream(s.cryptoStream)
|
||||||
|
cs, err := handshake.NewCryptoSetupTLSClient(
|
||||||
|
s.cryptoStream,
|
||||||
|
s.connectionID,
|
||||||
|
hostname,
|
||||||
|
aeadChanged,
|
||||||
|
tls,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.cryptoSetup = cs
|
||||||
|
return s, s.postSetup(initialPacketNumber)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) preSetup() {
|
||||||
|
s.rttStats = &congestion.RTTStats{}
|
||||||
|
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||||
|
protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||||
|
s.rttStats,
|
||||||
|
)
|
||||||
|
s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
||||||
s.handshakeChan = make(chan handshakeEvent, 3)
|
s.handshakeChan = make(chan handshakeEvent, 3)
|
||||||
s.handshakeCompleteChan = make(chan error, 1)
|
s.handshakeCompleteChan = make(chan error, 1)
|
||||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||||
|
@ -185,91 +316,14 @@ func (s *session) setup(
|
||||||
s.lastNetworkActivityTime = now
|
s.lastNetworkActivityTime = now
|
||||||
s.sessionCreationTime = now
|
s.sessionCreationTime = now
|
||||||
|
|
||||||
s.rttStats = &congestion.RTTStats{}
|
|
||||||
transportParams := &handshake.TransportParameters{
|
|
||||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
|
||||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
|
||||||
MaxStreams: protocol.MaxIncomingStreams,
|
|
||||||
IdleTimeout: s.config.IdleTimeout,
|
|
||||||
}
|
|
||||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
||||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
|
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
|
||||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
|
||||||
protocol.ReceiveConnectionFlowControlWindow,
|
|
||||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
|
||||||
s.rttStats,
|
|
||||||
)
|
|
||||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
|
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
|
||||||
s.cryptoStream = s.newStream(s.version.CryptoStreamID())
|
|
||||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
|
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
|
||||||
|
|
||||||
var err error
|
|
||||||
if s.perspective == protocol.PerspectiveServer {
|
|
||||||
verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool {
|
|
||||||
return s.config.AcceptCookie(clientAddr, cookie)
|
|
||||||
}
|
|
||||||
if s.version.UsesTLS() {
|
|
||||||
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
|
|
||||||
s.cryptoStream,
|
|
||||||
s.connectionID,
|
|
||||||
tlsConf,
|
|
||||||
s.conn.RemoteAddr(),
|
|
||||||
transportParams,
|
|
||||||
paramsChan,
|
|
||||||
aeadChanged,
|
|
||||||
verifySourceAddr,
|
|
||||||
s.config.Versions,
|
|
||||||
s.version,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
s.cryptoSetup, err = newCryptoSetup(
|
|
||||||
s.cryptoStream,
|
|
||||||
s.connectionID,
|
|
||||||
s.conn.RemoteAddr(),
|
|
||||||
s.version,
|
|
||||||
scfg,
|
|
||||||
transportParams,
|
|
||||||
s.config.Versions,
|
|
||||||
verifySourceAddr,
|
|
||||||
paramsChan,
|
|
||||||
aeadChanged,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
|
|
||||||
if s.version.UsesTLS() {
|
|
||||||
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
|
|
||||||
s.cryptoStream,
|
|
||||||
s.connectionID,
|
|
||||||
hostname,
|
|
||||||
tlsConf,
|
|
||||||
transportParams,
|
|
||||||
paramsChan,
|
|
||||||
aeadChanged,
|
|
||||||
initialVersion,
|
|
||||||
s.config.Versions,
|
|
||||||
s.version,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
s.cryptoSetup, err = newCryptoSetupClient(
|
|
||||||
s.cryptoStream,
|
|
||||||
hostname,
|
|
||||||
s.connectionID,
|
|
||||||
s.version,
|
|
||||||
tlsConf,
|
|
||||||
transportParams,
|
|
||||||
paramsChan,
|
|
||||||
aeadChanged,
|
|
||||||
initialVersion,
|
|
||||||
negotiatedVersions,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.packer = newPacketPacker(s.connectionID,
|
s.packer = newPacketPacker(s.connectionID,
|
||||||
|
initialPacketNumber,
|
||||||
s.cryptoSetup,
|
s.cryptoSetup,
|
||||||
s.streamFramer,
|
s.streamFramer,
|
||||||
s.perspective,
|
s.perspective,
|
||||||
|
@ -604,7 +658,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
|
||||||
s.cryptoStream.Cancel(quicErr)
|
s.cryptoStream.Cancel(quicErr)
|
||||||
s.streamsMap.CloseWithError(quicErr)
|
s.streamsMap.CloseWithError(quicErr)
|
||||||
|
|
||||||
if closeErr.err == errCloseSessionForNewVersion {
|
if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -893,6 +947,10 @@ func (s *session) handshakeStatus() <-chan handshakeEvent {
|
||||||
return s.handshakeChan
|
return s.handshakeChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *session) getCryptoStream() cryptoStream {
|
||||||
|
return s.cryptoStream
|
||||||
|
}
|
||||||
|
|
||||||
func (s *session) GetVersion() protocol.VersionNumber {
|
func (s *session) GetVersion() protocol.VersionNumber {
|
||||||
return s.version
|
return s.version
|
||||||
}
|
}
|
||||||
|
|
|
@ -468,9 +468,6 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("handles CONNECTION_CLOSE frames", func() {
|
It("handles CONNECTION_CLOSE frames", func() {
|
||||||
cryptoStream := mocks.NewMockStreamI(mockCtrl)
|
|
||||||
cryptoStream.EXPECT().Cancel(gomock.Any())
|
|
||||||
sess.cryptoStream = cryptoStream
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
@ -771,10 +768,15 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("sending packets", func() {
|
Context("sending packets", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||||
|
})
|
||||||
|
|
||||||
It("sends ACK frames", func() {
|
It("sends ACK frames", func() {
|
||||||
packetNumber := protocol.PacketNumber(0x035e)
|
packetNumber := protocol.PacketNumber(0x035e)
|
||||||
sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
|
err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
|
||||||
err := sess.sendPacket()
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
err = sess.sendPacket()
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(mconn.written).To(HaveLen(1))
|
Expect(mconn.written).To(HaveLen(1))
|
||||||
Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e}))))
|
Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e}))))
|
||||||
|
@ -858,6 +860,7 @@ var _ = Describe("Session", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
// a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet
|
// a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet
|
||||||
sess.packer.packetNumberGenerator.next = 0x1337 + 10
|
sess.packer.packetNumberGenerator.next = 0x1337 + 10
|
||||||
|
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||||
sph = newMockSentPacketHandler().(*mockSentPacketHandler)
|
sph = newMockSentPacketHandler().(*mockSentPacketHandler)
|
||||||
sess.sentPacketHandler = sph
|
sess.sentPacketHandler = sph
|
||||||
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
||||||
|
@ -981,6 +984,7 @@ var _ = Describe("Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("retransmits RTO packets", func() {
|
It("retransmits RTO packets", func() {
|
||||||
|
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||||
sess.sentPacketHandler.SetHandshakeComplete()
|
sess.sentPacketHandler.SetHandshakeComplete()
|
||||||
n := protocol.PacketNumber(10)
|
n := protocol.PacketNumber(10)
|
||||||
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
||||||
|
@ -1008,6 +1012,7 @@ var _ = Describe("Session", func() {
|
||||||
|
|
||||||
Context("scheduling sending", func() {
|
Context("scheduling sending", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||||
sess.processTransportParameters(&handshake.TransportParameters{
|
sess.processTransportParameters(&handshake.TransportParameters{
|
||||||
StreamFlowControlWindow: protocol.MaxByteCount,
|
StreamFlowControlWindow: protocol.MaxByteCount,
|
||||||
ConnectionFlowControlWindow: protocol.MaxByteCount,
|
ConnectionFlowControlWindow: protocol.MaxByteCount,
|
||||||
|
@ -1291,6 +1296,7 @@ var _ = Describe("Session", func() {
|
||||||
sess.handshakeComplete = true
|
sess.handshakeComplete = true
|
||||||
sess.config.KeepAlive = true
|
sess.config.KeepAlive = true
|
||||||
sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2)
|
sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2)
|
||||||
|
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||||
go sess.run()
|
go sess.run()
|
||||||
defer sess.Close(nil)
|
defer sess.Close(nil)
|
||||||
var data []byte
|
var data []byte
|
||||||
|
@ -1551,7 +1557,10 @@ var _ = Describe("Client Session", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("passes the diversification nonce to the cryptoSetup", func() {
|
It("passes the diversification nonce to the cryptoSetup", func() {
|
||||||
go sess.run()
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
sess.run()
|
||||||
|
}()
|
||||||
hdr.PacketNumber = 5
|
hdr.PacketNumber = 5
|
||||||
hdr.DiversificationNonce = []byte("foobar")
|
hdr.DiversificationNonce = []byte("foobar")
|
||||||
err := sess.handlePacketImpl(&receivedPacket{header: hdr})
|
err := sess.handlePacketImpl(&receivedPacket{header: hdr})
|
||||||
|
|
13
stream.go
13
stream.go
|
@ -32,6 +32,11 @@ type streamI interface {
|
||||||
IsFlowControlBlocked() bool
|
IsFlowControlBlocked() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cryptoStream interface {
|
||||||
|
streamI
|
||||||
|
SetReadOffset(protocol.ByteCount)
|
||||||
|
}
|
||||||
|
|
||||||
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
|
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
|
||||||
//
|
//
|
||||||
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
|
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
|
||||||
|
@ -481,3 +486,11 @@ func (s *stream) IsFlowControlBlocked() bool {
|
||||||
func (s *stream) GetWindowUpdate() protocol.ByteCount {
|
func (s *stream) GetWindowUpdate() protocol.ByteCount {
|
||||||
return s.flowController.GetWindowUpdate()
|
return s.flowController.GetWindowUpdate()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetReadOffset sets the read offset.
|
||||||
|
// It is only needed for the crypto stream.
|
||||||
|
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||||
|
func (s *stream) SetReadOffset(offset protocol.ByteCount) {
|
||||||
|
s.readOffset = offset
|
||||||
|
s.frameQueue.readPosition = offset
|
||||||
|
}
|
||||||
|
|
|
@ -266,6 +266,12 @@ var _ = Describe("Stream", func() {
|
||||||
Expect(onDataCalled).To(BeTrue())
|
Expect(onDataCalled).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("sets the read offset", func() {
|
||||||
|
str.SetReadOffset(0x42)
|
||||||
|
Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42)))
|
||||||
|
Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
|
||||||
|
})
|
||||||
|
|
||||||
Context("deadlines", func() {
|
Context("deadlines", func() {
|
||||||
It("the deadline error has the right net.Error properties", func() {
|
It("the deadline error has the right net.Error properties", func() {
|
||||||
Expect(errDeadline.Temporary()).To(BeTrue())
|
Expect(errDeadline.Temporary()).To(BeTrue())
|
||||||
|
|
|
@ -325,4 +325,5 @@ func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
m.maxOutgoingStreams = limit
|
m.maxOutgoingStreams = limit
|
||||||
|
m.openStreamOrErrCond.Broadcast()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue