mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 12:47:36 +03:00
implement stateless handling of Initial packets for the TLS server
This commit is contained in:
parent
57c6f3ceb5
commit
25a6dc9654
36 changed files with 1617 additions and 724 deletions
150
client.go
150
client.go
|
@ -10,6 +10,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
|
@ -23,13 +24,17 @@ type client struct {
|
|||
hostname string
|
||||
|
||||
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
|
||||
versionNegotiated bool // has version negotiation completed yet
|
||||
versionNegotiated bool // has the server accepted our version
|
||||
receivedVersionNegotiationPacket bool
|
||||
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
tls handshake.MintTLS // only used when using TLS
|
||||
|
||||
connectionID protocol.ConnectionID
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
version protocol.VersionNumber
|
||||
|
||||
session packetHandler
|
||||
|
@ -91,7 +96,6 @@ func DialNonFWSecure(
|
|||
if tlsConf != nil {
|
||||
hostname = tlsConf.ServerName
|
||||
}
|
||||
|
||||
if hostname == "" {
|
||||
hostname, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
|
@ -111,8 +115,9 @@ func DialNonFWSecure(
|
|||
}
|
||||
|
||||
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
|
||||
go c.listen()
|
||||
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session.(NonFWSession), nil
|
||||
|
@ -177,25 +182,79 @@ func populateClientConfig(config *Config) *Config {
|
|||
}
|
||||
}
|
||||
|
||||
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
if err := c.createNewSession(c.version, nil); err != nil {
|
||||
func (c *client) dial() error {
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS()
|
||||
} else {
|
||||
err = c.dialGQUIC()
|
||||
}
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial()
|
||||
}
|
||||
return err
|
||||
}
|
||||
go c.listen()
|
||||
|
||||
func (c *client) dialGQUIC() error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.establishSecureConnection()
|
||||
}
|
||||
|
||||
func (c *client) dialTLS() error {
|
||||
csc := handshake.NewCryptoStreamConn(nil)
|
||||
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mintConf.ServerName = c.hostname
|
||||
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
MaxStreams: protocol.MaxIncomingStreams,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
}
|
||||
eh := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
|
||||
if err := c.tls.SetExtensionHandler(eh); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
if err != handshake.ErrCloseSessionForRetry {
|
||||
return err
|
||||
}
|
||||
utils.Infof("Received a Retry packet. Recreating session.")
|
||||
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.establishSecureConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
func (c *client) establishSecureConnection() error {
|
||||
var runErr error
|
||||
errorChan := make(chan struct{})
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
runErr = c.session.run()
|
||||
if runErr == errCloseSessionForNewVersion {
|
||||
// run the new session
|
||||
runErr = c.session.run()
|
||||
}
|
||||
runErr = c.session.run() // returns as soon as the session is closed
|
||||
close(errorChan)
|
||||
utils.Infof("Connection %x closed.", c.connectionID)
|
||||
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
|
||||
c.conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// wait until the server accepts the QUIC version (or an error occurs)
|
||||
|
@ -219,7 +278,8 @@ func (c *client) establishSecureConnection() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Listen listens
|
||||
// Listen listens on the underlying connection and passes packets on for handling.
|
||||
// It returns when the connection is closed.
|
||||
func (c *client) listen() {
|
||||
var err error
|
||||
|
||||
|
@ -233,13 +293,15 @@ func (c *client) listen() {
|
|||
n, addr, err = c.conn.Read(data)
|
||||
if err != nil {
|
||||
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
c.mutex.Lock()
|
||||
if c.session != nil {
|
||||
c.session.Close(err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
data = data[:n]
|
||||
|
||||
c.handlePacket(addr, data)
|
||||
c.handlePacket(addr, data[:n])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -257,15 +319,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
|
||||
return
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||
return
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// reject packets with the wrong connection ID
|
||||
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.ResetFlag {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
|
@ -305,6 +368,8 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
|
|||
close(c.versionNegotiationChan)
|
||||
}
|
||||
|
||||
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
|
||||
|
||||
c.session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
|
@ -323,15 +388,15 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
}
|
||||
}
|
||||
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
|
||||
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
|
||||
if !ok {
|
||||
return qerr.InvalidVersion
|
||||
}
|
||||
c.receivedVersionNegotiationPacket = true
|
||||
c.negotiatedVersions = hdr.SupportedVersions
|
||||
|
||||
// switch to negotiated version
|
||||
initialVersion := c.version
|
||||
c.initialVersion = c.version
|
||||
c.version = newVersion
|
||||
var err error
|
||||
c.connectionID, err = utils.GenerateConnectionID()
|
||||
|
@ -339,17 +404,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
|||
return err
|
||||
}
|
||||
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
|
||||
|
||||
// create a new session and close the old one
|
||||
// the new session must be created first to update client member variables
|
||||
oldSession := c.session
|
||||
defer oldSession.Close(errCloseSessionForNewVersion)
|
||||
return c.createNewSession(initialVersion, hdr.SupportedVersions)
|
||||
c.session.Close(errCloseSessionForNewVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error {
|
||||
var err error
|
||||
utils.Debugf("createNewSession with initial version %s", initialVersion)
|
||||
func (c *client) createNewGQUICSession() (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
|
@ -357,8 +418,27 @@ func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotia
|
|||
c.connectionID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
initialVersion,
|
||||
negotiatedVersions,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
version protocol.VersionNumber,
|
||||
) (err error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.session, err = newTLSClientSession(
|
||||
c.conn,
|
||||
c.hostname,
|
||||
c.version,
|
||||
c.connectionID,
|
||||
c.config,
|
||||
c.tls,
|
||||
paramsChan,
|
||||
1,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
|
130
client_test.go
130
client_test.go
|
@ -8,6 +8,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
@ -45,10 +46,9 @@ var _ = Describe("Client", func() {
|
|||
msess, _ := newMockSession(nil, 0, 0, nil, nil, nil)
|
||||
sess = msess.(*mockSession)
|
||||
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
packetConn = &mockPacketConn{
|
||||
addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234},
|
||||
dataReadFrom: addr,
|
||||
}
|
||||
packetConn = newMockPacketConn()
|
||||
packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
||||
packetConn.dataReadFrom = addr
|
||||
config = &Config{
|
||||
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ var _ = Describe("Client", func() {
|
|||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, error) {
|
||||
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
|
||||
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
|
||||
return sess, nil
|
||||
}
|
||||
origGenerateConnectionID = generateConnectionID
|
||||
|
@ -101,7 +101,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("dials non-forward-secure", func() {
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -151,7 +151,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
|
||||
It("Dial only returns after the handshake is complete", func() {
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -237,7 +237,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("returns an error that occurs while waiting for the connection to become secure", func() {
|
||||
testErr := errors.New("early handshake error")
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -251,7 +251,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("returns an error that occurs while waiting for the handshake to complete", func() {
|
||||
testErr := errors.New("late handshake error")
|
||||
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -330,10 +330,6 @@ var _ = Describe("Client", func() {
|
|||
newVersion := protocol.VersionNumber(77)
|
||||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
packetConn.dataToRead = wire.ComposeGQUICVersionNegotiation(
|
||||
cl.connectionID,
|
||||
[]protocol.VersionNumber{newVersion},
|
||||
)
|
||||
sessionChan := make(chan *mockSession)
|
||||
handshakeChan := make(chan handshakeEvent)
|
||||
newClientSession = func(
|
||||
|
@ -348,10 +344,7 @@ var _ = Describe("Client", func() {
|
|||
) (packetHandler, error) {
|
||||
initialVersion = initialVersionP
|
||||
negotiatedVersions = negotiatedVersionsP
|
||||
// make the server accept the new version
|
||||
if len(negotiatedVersionsP) > 0 {
|
||||
packetConn.dataToRead = acceptClientVersionPacket(connectionID)
|
||||
}
|
||||
|
||||
sess := &mockSession{
|
||||
connectionID: connectionID,
|
||||
stopRunLoop: make(chan struct{}),
|
||||
|
@ -364,18 +357,26 @@ var _ = Describe("Client", func() {
|
|||
established := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cl.establishSecureConnection()
|
||||
err := cl.dial()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(established)
|
||||
}()
|
||||
go cl.listen()
|
||||
|
||||
actualInitialVersion := cl.version
|
||||
var firstSession, secondSession *mockSession
|
||||
Eventually(sessionChan).Should(Receive(&firstSession))
|
||||
Eventually(sessionChan).Should(Receive(&secondSession))
|
||||
packetConn.dataToRead <- wire.ComposeGQUICVersionNegotiation(
|
||||
cl.connectionID,
|
||||
[]protocol.VersionNumber{newVersion},
|
||||
)
|
||||
// it didn't pass the version negoation packet to the old session (since it has no payload)
|
||||
Expect(firstSession.packetCount).To(BeZero())
|
||||
Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
|
||||
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
|
||||
Expect(firstSession.packetCount).To(BeZero())
|
||||
Eventually(sessionChan).Should(Receive(&secondSession))
|
||||
// make the server accept the new version
|
||||
packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID)
|
||||
Consistently(func() bool { return secondSession.closed }).Should(BeFalse())
|
||||
Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337))
|
||||
Expect(negotiatedVersions).To(ContainElement(newVersion))
|
||||
|
@ -398,20 +399,23 @@ var _ = Describe("Client", func() {
|
|||
_ []protocol.VersionNumber,
|
||||
) (packetHandler, error) {
|
||||
atomic.AddUint32(&sessionCounter, 1)
|
||||
return sess, nil
|
||||
return &mockSession{
|
||||
connectionID: connectionID,
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
go cl.establishSecureConnection()
|
||||
go cl.dial()
|
||||
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
|
||||
newVersion := protocol.VersionNumber(77)
|
||||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
|
||||
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2))
|
||||
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
||||
newVersion = protocol.VersionNumber(78)
|
||||
Expect(newVersion).ToNot(Equal(cl.version))
|
||||
Expect(config.Versions).To(ContainElement(newVersion))
|
||||
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
|
||||
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2))
|
||||
Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
|
||||
})
|
||||
|
||||
It("errors if no matching version is found", func() {
|
||||
|
@ -482,7 +486,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(sess.closed).To(BeFalse())
|
||||
})
|
||||
|
||||
It("creates new sessions with the right parameters", func() {
|
||||
It("creates new GQUIC sessions with the right parameters", func() {
|
||||
closeErr := errors.New("peer doesn't reply")
|
||||
c := make(chan struct{})
|
||||
var cconn connection
|
||||
|
@ -516,12 +520,84 @@ var _ = Describe("Client", func() {
|
|||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||
Expect(hostname).To(Equal("quic.clemente.io"))
|
||||
Expect(version).To(Equal(cl.version))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
sess.Close(closeErr)
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates new TLS sessions with the right parameters", func() {
|
||||
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
||||
c := make(chan struct{})
|
||||
var cconn connection
|
||||
var hostname string
|
||||
var version protocol.VersionNumber
|
||||
var conf *Config
|
||||
newTLSClientSession = func(
|
||||
connP connection,
|
||||
hostnameP string,
|
||||
versionP protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
configP *Config,
|
||||
tls handshake.MintTLS,
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
_ protocol.PacketNumber,
|
||||
) (packetHandler, error) {
|
||||
cconn = connP
|
||||
hostname = hostnameP
|
||||
version = versionP
|
||||
conf = configP
|
||||
close(c)
|
||||
return sess, nil
|
||||
}
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
close(dialed)
|
||||
}()
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||
Expect(hostname).To(Equal("quic.clemente.io"))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
sess.Close(errors.New("peer doesn't reply"))
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("creates a new session when the server performs a retry", func() {
|
||||
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
||||
sessionChan := make(chan *mockSession)
|
||||
newTLSClientSession = func(
|
||||
connP connection,
|
||||
hostnameP string,
|
||||
versionP protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
configP *Config,
|
||||
tls handshake.MintTLS,
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
_ protocol.PacketNumber,
|
||||
) (packetHandler, error) {
|
||||
sess := &mockSession{
|
||||
stopRunLoop: make(chan struct{}),
|
||||
}
|
||||
sessionChan <- sess
|
||||
return sess, nil
|
||||
}
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
close(dialed)
|
||||
}()
|
||||
var firstSession, secondSession *mockSession
|
||||
Eventually(sessionChan).Should(Receive(&firstSession))
|
||||
firstSession.Close(handshake.ErrCloseSessionForRetry)
|
||||
Eventually(sessionChan).Should(Receive(&secondSession))
|
||||
secondSession.Close(errors.New("stop test"))
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
})
|
||||
|
||||
Context("handling packets", func() {
|
||||
It("handles packets", func() {
|
||||
ph := wire.Header{
|
||||
|
@ -532,7 +608,7 @@ var _ = Describe("Client", func() {
|
|||
b := &bytes.Buffer{}
|
||||
err := ph.Write(b, protocol.PerspectiveServer, cl.version)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
packetConn.dataToRead = b.Bytes()
|
||||
packetConn.dataToRead <- b.Bytes()
|
||||
|
||||
Expect(sess.packetCount).To(BeZero())
|
||||
stoppedListening := make(chan struct{})
|
||||
|
|
31
conn_test.go
31
conn_test.go
|
@ -2,7 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
type mockPacketConn struct {
|
||||
addr net.Addr
|
||||
dataToRead []byte
|
||||
dataToRead chan []byte
|
||||
dataReadFrom net.Addr
|
||||
readErr error
|
||||
dataWritten bytes.Buffer
|
||||
|
@ -20,23 +20,34 @@ type mockPacketConn struct {
|
|||
closed bool
|
||||
}
|
||||
|
||||
func newMockPacketConn() *mockPacketConn {
|
||||
return &mockPacketConn{
|
||||
dataToRead: make(chan []byte, 1000),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
if c.readErr != nil {
|
||||
return 0, nil, c.readErr
|
||||
}
|
||||
if c.dataToRead == nil { // block if there's no data
|
||||
time.Sleep(time.Hour)
|
||||
return 0, nil, io.EOF
|
||||
data, ok := <-c.dataToRead
|
||||
if !ok {
|
||||
return 0, nil, errors.New("connection closed")
|
||||
}
|
||||
n := copy(b, c.dataToRead)
|
||||
c.dataToRead = nil
|
||||
n := copy(b, data)
|
||||
return n, c.dataReadFrom, nil
|
||||
}
|
||||
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
c.dataWrittenTo = addr
|
||||
return c.dataWritten.Write(b)
|
||||
}
|
||||
func (c *mockPacketConn) Close() error { c.closed = true; return nil }
|
||||
func (c *mockPacketConn) Close() error {
|
||||
if !c.closed {
|
||||
close(c.dataToRead)
|
||||
}
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
|
||||
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
|
||||
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
|
||||
|
@ -53,7 +64,7 @@ var _ = Describe("Connection", func() {
|
|||
IP: net.IPv4(192, 168, 100, 200),
|
||||
Port: 1337,
|
||||
}
|
||||
packetConn = &mockPacketConn{}
|
||||
packetConn = newMockPacketConn()
|
||||
c = &conn{
|
||||
currentAddr: addr,
|
||||
pconn: packetConn,
|
||||
|
@ -68,7 +79,7 @@ var _ = Describe("Connection", func() {
|
|||
})
|
||||
|
||||
It("reads", func() {
|
||||
packetConn.dataToRead = []byte("foo")
|
||||
packetConn.dataToRead <- []byte("foo")
|
||||
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
|
||||
p := make([]byte, 10)
|
||||
n, raddr, err := c.Read(p)
|
||||
|
|
|
@ -7,33 +7,33 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type cookieHandler struct {
|
||||
type CookieHandler struct {
|
||||
callback func(net.Addr, *Cookie) bool
|
||||
|
||||
cookieGenerator *CookieGenerator
|
||||
}
|
||||
|
||||
var _ mint.CookieHandler = &cookieHandler{}
|
||||
var _ mint.CookieHandler = &CookieHandler{}
|
||||
|
||||
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) {
|
||||
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cookieHandler{
|
||||
return &CookieHandler{
|
||||
callback: callback,
|
||||
cookieGenerator: cookieGenerator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
|
||||
if h.callback(conn.RemoteAddr(), nil) {
|
||||
return nil, nil
|
||||
}
|
||||
return h.cookieGenerator.NewToken(conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
|
||||
data, err := h.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
|
||||
|
|
|
@ -2,6 +2,7 @@ package handshake
|
|||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
|
||||
|
@ -9,22 +10,37 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ net.Conn = &mockConn{}
|
||||
|
||||
func (c *mockConn) Read([]byte) (int, error) { panic("not implemented") }
|
||||
func (c *mockConn) Write([]byte) (int, error) { panic("not implemented") }
|
||||
func (c *mockConn) Close() error { panic("not implemented") }
|
||||
func (c *mockConn) LocalAddr() net.Addr { panic("not implemented") }
|
||||
func (c *mockConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *mockConn) SetReadDeadline(time.Time) error { panic("not implemented") }
|
||||
func (c *mockConn) SetWriteDeadline(time.Time) error { panic("not implemented") }
|
||||
func (c *mockConn) SetDeadline(time.Time) error { panic("not implemented") }
|
||||
|
||||
var callbackReturn bool
|
||||
var mockCallback = func(net.Addr, *Cookie) bool {
|
||||
return callbackReturn
|
||||
}
|
||||
|
||||
var _ = Describe("Cookie Handler", func() {
|
||||
var ch *cookieHandler
|
||||
var ch *CookieHandler
|
||||
var conn *mint.Conn
|
||||
|
||||
BeforeEach(func() {
|
||||
callbackReturn = false
|
||||
var err error
|
||||
ch, err = newCookieHandler(mockCallback)
|
||||
ch, err = NewCookieHandler(mockCallback)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46}
|
||||
conn = mint.NewConn(&fakeConn{remoteAddr: addr}, &mint.Config{}, false)
|
||||
conn = mint.NewConn(&mockConn{remoteAddr: addr}, &mint.Config{}, false)
|
||||
})
|
||||
|
||||
It("generates and validates a token", func() {
|
||||
|
|
|
@ -381,10 +381,6 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
|
|||
h.divNonceChan <- data
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
h.clientHelloCounter++
|
||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||
|
|
|
@ -458,10 +458,6 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
|
|||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
|
||||
panic("not needed for cryptoSetupServer")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||
if len(nonce) != 32 {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
|
@ -12,6 +11,9 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
|
||||
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
|
||||
|
||||
// KeyDerivationFunction is used for key derivation
|
||||
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
|
@ -20,68 +22,31 @@ type cryptoSetupTLS struct {
|
|||
|
||||
perspective protocol.Perspective
|
||||
|
||||
tls mintTLS
|
||||
conn *fakeConn
|
||||
|
||||
nextPacketType protocol.PacketType
|
||||
|
||||
keyDerivation KeyDerivationFunction
|
||||
nullAEAD crypto.AEAD
|
||||
aead crypto.AEAD
|
||||
|
||||
tls MintTLS
|
||||
cryptoStream *CryptoStreamConn
|
||||
aeadChanged chan<- protocol.EncryptionLevel
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
|
||||
func NewCryptoSetupTLSServer(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
tlsConfig *tls.Config,
|
||||
remoteAddr net.Addr,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
tls MintTLS,
|
||||
cryptoStream *CryptoStreamConn,
|
||||
nullAEAD crypto.AEAD,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
checkCookie func(net.Addr, *Cookie) bool,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mintConf.RequireCookie = true
|
||||
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mintConf.CookieProtector, err = mint.NewDefaultCookieProtector()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := &fakeConn{
|
||||
stream: cryptoStream,
|
||||
pers: protocol.PerspectiveServer,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
mintConn := mint.Server(conn, mintConf)
|
||||
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
) CryptoSetup {
|
||||
return &cryptoSetupTLS{
|
||||
perspective: protocol.PerspectiveServer,
|
||||
tls: &mintController{mintConn},
|
||||
conn: conn,
|
||||
tls: tls,
|
||||
cryptoStream: cryptoStream,
|
||||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
|
||||
|
@ -89,60 +54,44 @@ func NewCryptoSetupTLSClient(
|
|||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
hostname string,
|
||||
tlsConfig *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
aeadChanged chan<- protocol.EncryptionLevel,
|
||||
initialVersion protocol.VersionNumber,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
tls MintTLS,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, error) {
|
||||
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mintConf.ServerName = hostname
|
||||
conn := &fakeConn{
|
||||
stream: cryptoStream,
|
||||
pers: protocol.PerspectiveClient,
|
||||
}
|
||||
mintConn := mint.Client(conn, mintConf)
|
||||
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
|
||||
if err := mintConn.SetExtensionHandler(eh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cryptoSetupTLS{
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
tls: &mintController{mintConn},
|
||||
tls: tls,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
aeadChanged: aeadChanged,
|
||||
nextPacketType: protocol.PacketTypeInitial,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
handshakeLoop:
|
||||
for {
|
||||
switch alert := h.tls.Handshake(); alert {
|
||||
case mint.AlertStatelessRetry:
|
||||
case mint.AlertNoAlert: // handshake complete
|
||||
break handshakeLoop
|
||||
case mint.AlertWouldBlock:
|
||||
h.determineNextPacketType()
|
||||
if err := h.conn.Continue(); err != nil {
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
|
||||
// send out that data now
|
||||
if _, err := h.cryptoStream.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
handshakeLoop:
|
||||
for {
|
||||
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
switch h.tls.State() {
|
||||
case mint.StateClientStart: // this happens if a stateless retry is performed
|
||||
return ErrCloseSessionForRetry
|
||||
case mint.StateClientConnected, mint.StateServerConnected:
|
||||
break handshakeLoop
|
||||
}
|
||||
}
|
||||
|
||||
aead, err := h.keyDerivation(h.tls, h.perspective)
|
||||
|
@ -209,35 +158,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
|
|||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) determineNextPacketType() error {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
state := h.tls.State().HandshakeState
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
switch state {
|
||||
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
|
||||
h.nextPacketType = protocol.PacketTypeRetry
|
||||
case "ServerStateWaitFinished":
|
||||
h.nextPacketType = protocol.PacketTypeHandshake
|
||||
default:
|
||||
// TODO: accept 0-RTT data
|
||||
return fmt.Errorf("Unexpected handshake state: %s", state)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// client
|
||||
if state != "ClientStateWaitSH" {
|
||||
h.nextPacketType = protocol.PacketTypeHandshake
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
return h.nextPacketType
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
|
||||
panic("diversification nonce not needed for TLS")
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -10,7 +9,6 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -23,52 +21,33 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
|
|||
var _ = Describe("TLS Crypto Setup", func() {
|
||||
var (
|
||||
cs *cryptoSetupTLS
|
||||
paramsChan chan TransportParameters
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
paramsChan = make(chan TransportParameters)
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
csInt, err := NewCryptoSetupTLSServer(
|
||||
cs = NewCryptoSetupTLSServer(
|
||||
nil,
|
||||
1,
|
||||
testdata.GetTLSConfig(),
|
||||
nil,
|
||||
&TransportParameters{},
|
||||
paramsChan,
|
||||
NewCryptoStreamConn(nil),
|
||||
nil, // AEAD
|
||||
aeadChanged,
|
||||
nil,
|
||||
nil,
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupTLS)
|
||||
).(*cryptoSetupTLS)
|
||||
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
||||
})
|
||||
|
||||
It("errors when the handshake fails", func() {
|
||||
alert := mint.AlertBadRecordMAC
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert)
|
||||
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(alert)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
|
||||
})
|
||||
|
||||
It("continues shaking hands when mint says that it would block", func() {
|
||||
cs.conn.stream = &bytes.Buffer{}
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{})
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("derives keys", func() {
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -76,64 +55,22 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
Expect(aeadChanged).To(BeClosed())
|
||||
})
|
||||
|
||||
Context("determining the packet type", func() {
|
||||
Context("for the client", func() {
|
||||
var csClient *cryptoSetupTLS
|
||||
|
||||
BeforeEach(func() {
|
||||
csInt, err := NewCryptoSetupTLSClient(
|
||||
nil,
|
||||
1,
|
||||
"quic.clemente.io",
|
||||
testdata.GetTLSConfig(),
|
||||
&TransportParameters{},
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
protocol.VersionTLS,
|
||||
[]protocol.VersionNumber{protocol.VersionTLS},
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
It("handshakes until it is connected", func() {
|
||||
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerNegotiated).Times(9)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
csClient = csInt.(*cryptoSetupTLS)
|
||||
csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
})
|
||||
|
||||
It("sends a Client Initial first", func() {
|
||||
Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeInitial))
|
||||
})
|
||||
|
||||
It("sends a Handshake packet after the server sent a Server Hello", func() {
|
||||
csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"})
|
||||
err := csClient.determineNextPacketType()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("for the server", func() {
|
||||
BeforeEach(func() {
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
})
|
||||
|
||||
It("sends a Stateless Retry packet", func() {
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"})
|
||||
err := cs.determineNextPacketType()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeRetry))
|
||||
})
|
||||
|
||||
It("sends Handshake packet", func() {
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"})
|
||||
err := cs.determineNextPacketType()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeHandshake))
|
||||
})
|
||||
})
|
||||
Expect(aeadChanged).To(Receive())
|
||||
})
|
||||
|
||||
Context("escalating crypto", func() {
|
||||
doHandshake := func() {
|
||||
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -240,3 +177,33 @@ var _ = Describe("TLS Crypto Setup", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("TLS Crypto Setup, for the client", func() {
|
||||
var (
|
||||
cs *cryptoSetupTLS
|
||||
aeadChanged chan protocol.EncryptionLevel
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
aeadChanged = make(chan protocol.EncryptionLevel, 2)
|
||||
csInt, err := NewCryptoSetupTLSClient(
|
||||
nil,
|
||||
0,
|
||||
"quic.clemente.io",
|
||||
aeadChanged,
|
||||
nil, // mintTLS
|
||||
protocol.VersionTLS,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupTLS)
|
||||
})
|
||||
|
||||
It("returns when a retry is performed", func() {
|
||||
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateClientStart)
|
||||
err := cs.HandleCryptoStream()
|
||||
Expect(err).To(MatchError(ErrCloseSessionForRetry))
|
||||
})
|
||||
|
||||
})
|
||||
|
|
101
internal/handshake/crypto_stream_conn.go
Normal file
101
internal/handshake/crypto_stream_conn.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// The CryptoStreamConn is used as the net.Conn passed to mint.
|
||||
// It has two operating modes:
|
||||
// 1. It can read and write to bytes.Buffers.
|
||||
// 2. It can use a quic.Stream for reading and writing.
|
||||
// The buffer-mode is only used by the server, in order to statelessly handle retries.
|
||||
type CryptoStreamConn struct {
|
||||
remoteAddr net.Addr
|
||||
|
||||
// the buffers are used before the session is initialized
|
||||
readBuf bytes.Buffer
|
||||
writeBuf bytes.Buffer
|
||||
|
||||
// stream will be set once the session is initialized
|
||||
stream io.ReadWriter
|
||||
}
|
||||
|
||||
var _ net.Conn = &CryptoStreamConn{}
|
||||
|
||||
// NewCryptoStreamConn creates a new CryptoStreamConn
|
||||
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
|
||||
return &CryptoStreamConn{remoteAddr: remoteAddr}
|
||||
}
|
||||
|
||||
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
|
||||
if c.stream != nil {
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
return c.readBuf.Read(b)
|
||||
}
|
||||
|
||||
// AddDataForReading adds data to the read buffer.
|
||||
// This data will ONLY be read when the stream has not been set.
|
||||
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
|
||||
c.readBuf.Write(data)
|
||||
}
|
||||
|
||||
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
|
||||
if c.stream != nil {
|
||||
return c.stream.Write(p)
|
||||
}
|
||||
return c.writeBuf.Write(p)
|
||||
}
|
||||
|
||||
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
|
||||
func (c *CryptoStreamConn) GetDataForWriting() []byte {
|
||||
defer c.writeBuf.Reset()
|
||||
data := make([]byte, c.writeBuf.Len())
|
||||
copy(data, c.writeBuf.Bytes())
|
||||
return data
|
||||
}
|
||||
|
||||
// SetStream sets the stream.
|
||||
// After setting the stream, the read and write buffer won't be used any more.
|
||||
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
|
||||
c.stream = stream
|
||||
}
|
||||
|
||||
// Flush copies the contents of the write buffer to the stream
|
||||
func (c *CryptoStreamConn) Flush() (int, error) {
|
||||
n, err := io.Copy(c.stream, &c.writeBuf)
|
||||
return int(n), err
|
||||
}
|
||||
|
||||
// Close is not implemented
|
||||
func (c *CryptoStreamConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr is not implemented
|
||||
func (c *CryptoStreamConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address
|
||||
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
// SetReadDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline is not implemented
|
||||
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
67
internal/handshake/crypto_stream_conn_test.go
Normal file
67
internal/handshake/crypto_stream_conn_test.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("CryptoStreamConn", func() {
|
||||
var (
|
||||
csc *CryptoStreamConn
|
||||
remoteAddr net.Addr
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
csc = NewCryptoStreamConn(remoteAddr)
|
||||
})
|
||||
|
||||
It("reads from the read buffer, when no stream is set", func() {
|
||||
csc.AddDataForReading([]byte("foobar"))
|
||||
data := make([]byte, 4)
|
||||
n, err := csc.Read(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(4))
|
||||
Expect(data).To(Equal([]byte("foob")))
|
||||
})
|
||||
|
||||
It("writes to the write buffer, when no stream is set", func() {
|
||||
csc.Write([]byte("foo"))
|
||||
Expect(csc.GetDataForWriting()).To(Equal([]byte("foo")))
|
||||
csc.Write([]byte("bar"))
|
||||
Expect(csc.GetDataForWriting()).To(Equal([]byte("bar")))
|
||||
})
|
||||
|
||||
It("reads from the stream, if available", func() {
|
||||
csc.stream = &bytes.Buffer{}
|
||||
csc.stream.Write([]byte("foobar"))
|
||||
data := make([]byte, 3)
|
||||
n, err := csc.Read(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(3))
|
||||
Expect(data).To(Equal([]byte("foo")))
|
||||
})
|
||||
|
||||
It("writes to the stream, if available", func() {
|
||||
stream := &bytes.Buffer{}
|
||||
csc.SetStream(stream)
|
||||
csc.Write([]byte("foobar"))
|
||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("returns the remote address", func() {
|
||||
Expect(csc.RemoteAddr()).To(Equal(remoteAddr))
|
||||
})
|
||||
|
||||
It("has unimplemented methods", func() {
|
||||
Expect(csc.Close()).ToNot(HaveOccurred())
|
||||
Expect(csc.SetDeadline(time.Time{})).ToNot(HaveOccurred())
|
||||
Expect(csc.SetReadDeadline(time.Time{})).ToNot(HaveOccurred())
|
||||
Expect(csc.SetWriteDeadline(time.Time{})).ToNot(HaveOccurred())
|
||||
Expect(csc.LocalAddr()).To(BeNil())
|
||||
})
|
||||
})
|
|
@ -1,6 +1,10 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
|
@ -10,6 +14,26 @@ type Sealer interface {
|
|||
Overhead() int
|
||||
}
|
||||
|
||||
// A TLSExtensionHandler sends and received the QUIC TLS extension.
|
||||
// It provides the parameters sent by the peer on a channel.
|
||||
type TLSExtensionHandler interface {
|
||||
Send(mint.HandshakeType, *mint.ExtensionList) error
|
||||
Receive(mint.HandshakeType, *mint.ExtensionList) error
|
||||
GetPeerParams() <-chan TransportParameters
|
||||
}
|
||||
|
||||
// MintTLS combines some methods needed to interact with mint.
|
||||
type MintTLS interface {
|
||||
crypto.TLSExporter
|
||||
|
||||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
State() mint.State
|
||||
|
||||
SetCryptoStream(io.ReadWriter)
|
||||
SetExtensionHandler(mint.AppExtensionHandler) error
|
||||
}
|
||||
|
||||
// CryptoSetup is a crypto setup
|
||||
type CryptoSetup interface {
|
||||
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
|
@ -17,7 +41,6 @@ type CryptoSetup interface {
|
|||
// TODO: clean up this interface
|
||||
DiversificationNonce() []byte // only needed for cryptoSetupServer
|
||||
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
|
||||
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
gocrypto "crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
|
||||
mconf := &mint.Config{
|
||||
NonBlocking: true,
|
||||
CipherSuites: []mint.CipherSuite{
|
||||
mint.TLS_AES_128_GCM_SHA256,
|
||||
mint.TLS_AES_256_GCM_SHA384,
|
||||
},
|
||||
}
|
||||
if tlsConf != nil {
|
||||
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
|
||||
for i, certChain := range tlsConf.Certificates {
|
||||
mconf.Certificates[i] = &mint.Certificate{
|
||||
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
|
||||
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
|
||||
}
|
||||
for j, cert := range certChain.Certificate {
|
||||
c, err := x509.ParseCertificate(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mconf.Certificates[i].Chain[j] = c
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mconf, nil
|
||||
}
|
||||
|
||||
type mintTLS interface {
|
||||
// These two methods are the same as the crypto.TLSExporter interface.
|
||||
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
|
||||
GetCipherSuite() mint.CipherSuiteParams
|
||||
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
|
||||
// additional methods
|
||||
Handshake() mint.Alert
|
||||
State() mint.ConnectionState
|
||||
}
|
||||
|
||||
var _ crypto.TLSExporter = (mintTLS)(nil)
|
||||
|
||||
type mintController struct {
|
||||
conn *mint.Conn
|
||||
}
|
||||
|
||||
var _ mintTLS = &mintController{}
|
||||
|
||||
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
|
||||
return mc.conn.State().CipherSuite
|
||||
}
|
||||
|
||||
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
return mc.conn.ComputeExporter(label, context, keyLength)
|
||||
}
|
||||
|
||||
func (mc *mintController) Handshake() mint.Alert {
|
||||
return mc.conn.Handshake()
|
||||
}
|
||||
|
||||
func (mc *mintController) State() mint.ConnectionState {
|
||||
return mc.conn.State()
|
||||
}
|
||||
|
||||
// mint expects a net.Conn, but we're doing the handshake on a stream
|
||||
// so we wrap a stream such that implements a net.Conn
|
||||
type fakeConn struct {
|
||||
stream io.ReadWriter
|
||||
pers protocol.Perspective
|
||||
remoteAddr net.Addr
|
||||
|
||||
blockRead bool
|
||||
writeBuffer bytes.Buffer
|
||||
}
|
||||
|
||||
var _ net.Conn = &fakeConn{}
|
||||
|
||||
func (c *fakeConn) Read(b []byte) (int, error) {
|
||||
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
|
||||
return 0, nil
|
||||
}
|
||||
c.blockRead = true // block the next Read call
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
|
||||
func (c *fakeConn) Write(p []byte) (int, error) {
|
||||
if c.pers == protocol.PerspectiveClient {
|
||||
return c.stream.Write(p)
|
||||
}
|
||||
// Buffer all writes by the server.
|
||||
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
|
||||
return c.writeBuffer.Write(p)
|
||||
}
|
||||
|
||||
func (c *fakeConn) Continue() error {
|
||||
c.blockRead = false
|
||||
if c.pers == protocol.PerspectiveClient {
|
||||
return nil
|
||||
}
|
||||
// write all contents of the write buffer to the stream.
|
||||
_, err := c.stream.Write(c.writeBuffer.Bytes())
|
||||
c.writeBuffer.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *fakeConn) Close() error { return nil }
|
||||
func (c *fakeConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *fakeConn) SetDeadline(time.Time) error { return nil }
|
|
@ -1,72 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Fake Conn", func() {
|
||||
var (
|
||||
c *fakeConn
|
||||
stream *bytes.Buffer
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stream = &bytes.Buffer{}
|
||||
c = &fakeConn{stream: stream}
|
||||
})
|
||||
|
||||
Context("Reading", func() {
|
||||
It("doesn't return any new data after one Read call", func() {
|
||||
stream.Write([]byte("foobar"))
|
||||
b := make([]byte, 3)
|
||||
_, err := c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal([]byte("foo")))
|
||||
n, err := c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(BeZero())
|
||||
})
|
||||
|
||||
It("allows more Read calls after unblocking", func() {
|
||||
stream.Write([]byte("foobar"))
|
||||
b := make([]byte, 3)
|
||||
_, err := c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = c.Continue()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = c.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(b).To(Equal([]byte("bar")))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Writing", func() {
|
||||
It("writes directly when acting as a client", func() {
|
||||
c.pers = protocol.PerspectiveClient
|
||||
_, err := c.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("only writes after flushing when acting as a server", func() {
|
||||
c.pers = protocol.PerspectiveServer
|
||||
_, err := c.Write([]byte("foobar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Bytes()).To(BeEmpty())
|
||||
err = c.Continue()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
})
|
||||
|
||||
It("returns its remote address", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
c.remoteAddr = addr
|
||||
Expect(c.RemoteAddr()).To(Equal(addr))
|
||||
})
|
||||
})
|
|
@ -13,8 +13,8 @@ import (
|
|||
)
|
||||
|
||||
type extensionHandlerClient struct {
|
||||
params *TransportParameters
|
||||
paramsChan chan<- TransportParameters
|
||||
ourParams *TransportParameters
|
||||
paramsChan chan TransportParameters
|
||||
|
||||
initialVersion protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
|
@ -22,16 +22,17 @@ type extensionHandlerClient struct {
|
|||
}
|
||||
|
||||
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
|
||||
var _ TLSExtensionHandler = &extensionHandlerClient{}
|
||||
|
||||
func newExtensionHandlerClient(
|
||||
func NewExtensionHandlerClient(
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
initialVersion protocol.VersionNumber,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
version protocol.VersionNumber,
|
||||
) *extensionHandlerClient {
|
||||
) TLSExtensionHandler {
|
||||
paramsChan := make(chan TransportParameters, 1)
|
||||
return &extensionHandlerClient{
|
||||
params: params,
|
||||
ourParams: params,
|
||||
paramsChan: paramsChan,
|
||||
initialVersion: initialVersion,
|
||||
supportedVersions: supportedVersions,
|
||||
|
@ -46,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
|
|||
|
||||
data, err := syntax.Marshal(clientHelloTransportParameters{
|
||||
InitialVersion: uint32(h.initialVersion),
|
||||
Parameters: h.params.getTransportParameters(),
|
||||
Parameters: h.ourParams.getTransportParameters(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -123,3 +124,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
|
|||
h.paramsChan <- *params
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
|
||||
return h.paramsChan
|
||||
}
|
||||
|
|
|
@ -15,13 +15,10 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
|
|||
var (
|
||||
handler *extensionHandlerClient
|
||||
el mint.ExtensionList
|
||||
paramsChan chan TransportParameters
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
|
||||
paramsChan = make(chan TransportParameters, 1)
|
||||
handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever)
|
||||
handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever).(*extensionHandlerClient)
|
||||
el = make(mint.ExtensionList, 0)
|
||||
})
|
||||
|
||||
|
@ -81,7 +78,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
|
|||
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var params TransportParameters
|
||||
Expect(paramsChan).To(Receive(¶ms))
|
||||
Expect(handler.GetPeerParams()).To(Receive(¶ms))
|
||||
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
||||
})
|
||||
|
||||
|
|
|
@ -14,26 +14,27 @@ import (
|
|||
)
|
||||
|
||||
type extensionHandlerServer struct {
|
||||
params *TransportParameters
|
||||
paramsChan chan<- TransportParameters
|
||||
ourParams *TransportParameters
|
||||
paramsChan chan TransportParameters
|
||||
|
||||
version protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
|
||||
var _ TLSExtensionHandler = &extensionHandlerServer{}
|
||||
|
||||
func newExtensionHandlerServer(
|
||||
func NewExtensionHandlerServer(
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
version protocol.VersionNumber,
|
||||
) *extensionHandlerServer {
|
||||
) TLSExtensionHandler {
|
||||
paramsChan := make(chan TransportParameters, 1)
|
||||
return &extensionHandlerServer{
|
||||
params: params,
|
||||
ourParams: params,
|
||||
paramsChan: paramsChan,
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,7 +44,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
|
|||
}
|
||||
|
||||
transportParams := append(
|
||||
h.params.getTransportParameters(),
|
||||
h.ourParams.getTransportParameters(),
|
||||
// TODO(#855): generate a real token
|
||||
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
|
||||
)
|
||||
|
@ -105,3 +106,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
|
|||
h.paramsChan <- *params
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
|
||||
return h.paramsChan
|
||||
}
|
||||
|
|
|
@ -22,13 +22,10 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
|
|||
var (
|
||||
handler *extensionHandlerServer
|
||||
el mint.ExtensionList
|
||||
paramsChan chan TransportParameters
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
|
||||
paramsChan = make(chan TransportParameters, 1)
|
||||
handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever)
|
||||
handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever).(*extensionHandlerServer)
|
||||
el = make(mint.ExtensionList, 0)
|
||||
})
|
||||
|
||||
|
@ -91,7 +88,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
|
|||
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var params TransportParameters
|
||||
Expect(paramsChan).To(Receive(¶ms))
|
||||
Expect(handler.GetPeerParams()).To(Receive(¶ms))
|
||||
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
|
||||
})
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package mocks
|
||||
|
||||
//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go"
|
||||
//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
|
||||
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
|
||||
|
|
|
@ -1,83 +1,106 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ../handshake/mint_utils.go
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
|
||||
|
||||
package mockhandshake
|
||||
|
||||
import (
|
||||
io "io"
|
||||
reflect "reflect"
|
||||
|
||||
mint "github.com/bifurcation/mint"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockmintTLS is a mock of mintTLS interface
|
||||
type MockmintTLS struct {
|
||||
// MockMintTLS is a mock of MintTLS interface
|
||||
type MockMintTLS struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockmintTLSMockRecorder
|
||||
recorder *MockMintTLSMockRecorder
|
||||
}
|
||||
|
||||
// MockmintTLSMockRecorder is the mock recorder for MockmintTLS
|
||||
type MockmintTLSMockRecorder struct {
|
||||
mock *MockmintTLS
|
||||
// MockMintTLSMockRecorder is the mock recorder for MockMintTLS
|
||||
type MockMintTLSMockRecorder struct {
|
||||
mock *MockMintTLS
|
||||
}
|
||||
|
||||
// NewMockmintTLS creates a new mock instance
|
||||
func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS {
|
||||
mock := &MockmintTLS{ctrl: ctrl}
|
||||
mock.recorder = &MockmintTLSMockRecorder{mock}
|
||||
// NewMockMintTLS creates a new mock instance
|
||||
func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
|
||||
mock := &MockMintTLS{ctrl: ctrl}
|
||||
mock.recorder = &MockMintTLSMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder {
|
||||
func (_m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
|
||||
return _m.recorder
|
||||
}
|
||||
|
||||
// GetCipherSuite mocks base method
|
||||
func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams {
|
||||
ret := _m.ctrl.Call(_m, "GetCipherSuite")
|
||||
ret0, _ := ret[0].(mint.CipherSuiteParams)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetCipherSuite indicates an expected call of GetCipherSuite
|
||||
func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite))
|
||||
}
|
||||
|
||||
// ComputeExporter mocks base method
|
||||
func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength)
|
||||
func (_m *MockMintTLS) ComputeExporter(_param0 string, _param1 []byte, _param2 int) ([]byte, error) {
|
||||
ret := _m.ctrl.Call(_m, "ComputeExporter", _param0, _param1, _param2)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ComputeExporter indicates an expected call of ComputeExporter
|
||||
func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
|
||||
func (_mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetCipherSuite mocks base method
|
||||
func (_m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams {
|
||||
ret := _m.ctrl.Call(_m, "GetCipherSuite")
|
||||
ret0, _ := ret[0].(mint.CipherSuiteParams)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetCipherSuite indicates an expected call of GetCipherSuite
|
||||
func (_mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite))
|
||||
}
|
||||
|
||||
// Handshake mocks base method
|
||||
func (_m *MockmintTLS) Handshake() mint.Alert {
|
||||
func (_m *MockMintTLS) Handshake() mint.Alert {
|
||||
ret := _m.ctrl.Call(_m, "Handshake")
|
||||
ret0, _ := ret[0].(mint.Alert)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Handshake indicates an expected call of Handshake
|
||||
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake))
|
||||
func (_mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
|
||||
}
|
||||
|
||||
// SetCryptoStream mocks base method
|
||||
func (_m *MockMintTLS) SetCryptoStream(_param0 io.ReadWriter) {
|
||||
_m.ctrl.Call(_m, "SetCryptoStream", _param0)
|
||||
}
|
||||
|
||||
// SetCryptoStream indicates an expected call of SetCryptoStream
|
||||
func (_mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0)
|
||||
}
|
||||
|
||||
// SetExtensionHandler mocks base method
|
||||
func (_m *MockMintTLS) SetExtensionHandler(_param0 mint.AppExtensionHandler) error {
|
||||
ret := _m.ctrl.Call(_m, "SetExtensionHandler", _param0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetExtensionHandler indicates an expected call of SetExtensionHandler
|
||||
func (_mr *MockMintTLSMockRecorder) SetExtensionHandler(arg0 interface{}) *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetExtensionHandler", reflect.TypeOf((*MockMintTLS)(nil).SetExtensionHandler), arg0)
|
||||
}
|
||||
|
||||
// State mocks base method
|
||||
func (_m *MockmintTLS) State() mint.ConnectionState {
|
||||
func (_m *MockMintTLS) State() mint.State {
|
||||
ret := _m.ctrl.Call(_m, "State")
|
||||
ret0, _ := ret[0].(mint.ConnectionState)
|
||||
ret0, _ := ret[0].(mint.State)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// State indicates an expected call of State
|
||||
func (_mr *MockmintTLSMockRecorder) State() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State))
|
||||
func (_mr *MockMintTLSMockRecorder) State() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State))
|
||||
}
|
||||
|
|
71
internal/mocks/tls_extension_handler.go
Normal file
71
internal/mocks/tls_extension_handler.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler)
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
mint "github.com/bifurcation/mint"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
)
|
||||
|
||||
// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface
|
||||
type MockTLSExtensionHandler struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTLSExtensionHandlerMockRecorder
|
||||
}
|
||||
|
||||
// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler
|
||||
type MockTLSExtensionHandlerMockRecorder struct {
|
||||
mock *MockTLSExtensionHandler
|
||||
}
|
||||
|
||||
// NewMockTLSExtensionHandler creates a new mock instance
|
||||
func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler {
|
||||
mock := &MockTLSExtensionHandler{ctrl: ctrl}
|
||||
mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (_m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder {
|
||||
return _m.recorder
|
||||
}
|
||||
|
||||
// GetPeerParams mocks base method
|
||||
func (_m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters {
|
||||
ret := _m.ctrl.Call(_m, "GetPeerParams")
|
||||
ret0, _ := ret[0].(<-chan handshake.TransportParameters)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetPeerParams indicates an expected call of GetPeerParams
|
||||
func (_mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams))
|
||||
}
|
||||
|
||||
// Receive mocks base method
|
||||
func (_m *MockTLSExtensionHandler) Receive(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
|
||||
ret := _m.ctrl.Call(_m, "Receive", _param0, _param1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Receive indicates an expected call of Receive
|
||||
func (_mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1)
|
||||
}
|
||||
|
||||
// Send mocks base method
|
||||
func (_m *MockTLSExtensionHandler) Send(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
|
||||
ret := _m.ctrl.Call(_m, "Send", _param0, _param1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Send indicates an expected call of Send
|
||||
func (_mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
|
||||
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1)
|
||||
}
|
150
mint_utils.go
Normal file
150
mint_utils.go
Normal file
|
@ -0,0 +1,150 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
gocrypto "crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type mintController struct {
|
||||
csc *handshake.CryptoStreamConn
|
||||
conn *mint.Conn
|
||||
}
|
||||
|
||||
var _ handshake.MintTLS = &mintController{}
|
||||
|
||||
func newMintController(
|
||||
csc *handshake.CryptoStreamConn,
|
||||
mconf *mint.Config,
|
||||
pers protocol.Perspective,
|
||||
) handshake.MintTLS {
|
||||
var conn *mint.Conn
|
||||
if pers == protocol.PerspectiveClient {
|
||||
conn = mint.Client(csc, mconf)
|
||||
} else {
|
||||
conn = mint.Server(csc, mconf)
|
||||
}
|
||||
return &mintController{
|
||||
csc: csc,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
|
||||
return mc.conn.State().CipherSuite
|
||||
}
|
||||
|
||||
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
|
||||
return mc.conn.ComputeExporter(label, context, keyLength)
|
||||
}
|
||||
|
||||
func (mc *mintController) Handshake() mint.Alert {
|
||||
return mc.conn.Handshake()
|
||||
}
|
||||
|
||||
func (mc *mintController) State() mint.State {
|
||||
return mc.conn.State().HandshakeState
|
||||
}
|
||||
|
||||
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
|
||||
mc.csc.SetStream(stream)
|
||||
}
|
||||
|
||||
func (mc *mintController) SetExtensionHandler(h mint.AppExtensionHandler) error {
|
||||
return mc.conn.SetExtensionHandler(h)
|
||||
}
|
||||
|
||||
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
|
||||
mconf := &mint.Config{
|
||||
NonBlocking: true,
|
||||
CipherSuites: []mint.CipherSuite{
|
||||
mint.TLS_AES_128_GCM_SHA256,
|
||||
mint.TLS_AES_256_GCM_SHA384,
|
||||
},
|
||||
}
|
||||
if tlsConf != nil {
|
||||
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
|
||||
for i, certChain := range tlsConf.Certificates {
|
||||
mconf.Certificates[i] = &mint.Certificate{
|
||||
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
|
||||
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
|
||||
}
|
||||
for j, cert := range certChain.Certificate {
|
||||
c, err := x509.ParseCertificate(cert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mconf.Certificates[i].Chain[j] = c
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mconf, nil
|
||||
}
|
||||
|
||||
// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets
|
||||
// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0.
|
||||
func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) {
|
||||
unpacker := &packetUnpacker{aead: &nullAEAD{aead}, version: version}
|
||||
packet, err := unpacker.Unpack(hdr.Raw, hdr, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var frame *wire.StreamFrame
|
||||
for _, f := range packet.frames {
|
||||
var ok bool
|
||||
frame, ok = f.(*wire.StreamFrame)
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
if frame == nil {
|
||||
return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
|
||||
}
|
||||
// We don't need a check for the stream ID here.
|
||||
// The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
|
||||
if frame.Offset != 0 {
|
||||
return nil, errors.New("received stream data with non-zero offset")
|
||||
}
|
||||
if utils.Debug() {
|
||||
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID)
|
||||
hdr.Log()
|
||||
wire.LogFrame(frame, false)
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
// packUnencryptedPacket provides a low-overhead way to pack a packet.
|
||||
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
|
||||
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) {
|
||||
raw := getPacketBuffer()
|
||||
buffer := bytes.NewBuffer(raw)
|
||||
if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadStartIndex := buffer.Len()
|
||||
if err := sf.Write(buffer, hdr.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw = raw[0:buffer.Len()]
|
||||
_ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
|
||||
raw = raw[0 : buffer.Len()+aead.Overhead()]
|
||||
if utils.Debug() {
|
||||
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
|
||||
hdr.Log()
|
||||
wire.LogFrame(sf, true)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
111
mint_utils_test.go
Normal file
111
mint_utils_test.go
Normal file
|
@ -0,0 +1,111 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Packing and unpacking Initial packets", func() {
|
||||
var aead crypto.AEAD
|
||||
connID := protocol.ConnectionID(0x1337)
|
||||
ver := protocol.VersionTLS
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
PacketNumber: 0x42,
|
||||
ConnectionID: connID,
|
||||
Version: ver,
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
aead, err = crypto.NewNullAEAD(protocol.PerspectiveServer, connID, ver)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// set hdr.Raw
|
||||
buf := &bytes.Buffer{}
|
||||
err = hdr.Write(buf, protocol.PerspectiveServer, ver)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr.Raw = buf.Bytes()
|
||||
})
|
||||
|
||||
Context("unpacking", func() {
|
||||
packPacket := func(frames []wire.Frame) []byte {
|
||||
buf := &bytes.Buffer{}
|
||||
err := hdr.Write(buf, protocol.PerspectiveClient, ver)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
payloadStartIndex := buf.Len()
|
||||
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
|
||||
for _, f := range frames {
|
||||
err := f.Write(buf, ver)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
raw := buf.Bytes()
|
||||
return aeadCl.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
|
||||
}
|
||||
|
||||
It("unpacks a packet", func() {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 0,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
p := packPacket([]wire.Frame{f})
|
||||
frame, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(frame).To(Equal(f))
|
||||
})
|
||||
|
||||
It("rejects a packet that doesn't contain a STREAM_FRAME", func() {
|
||||
p := packPacket([]wire.Frame{&wire.PingFrame{}})
|
||||
_, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||
Expect(err).To(MatchError("Packet doesn't contain a STREAM_FRAME"))
|
||||
})
|
||||
|
||||
It("rejects a packet that has a STREAM_FRAME for the wrong stream", func() {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 42,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
p := packPacket([]wire.Frame{f})
|
||||
_, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||
Expect(err).To(MatchError("UnencryptedStreamData: received unencrypted stream data on stream 42"))
|
||||
})
|
||||
|
||||
It("rejects a packet that has a STREAM_FRAME with a non-zero offset", func() {
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: 0,
|
||||
Offset: 10,
|
||||
Data: []byte("foobar"),
|
||||
}
|
||||
p := packPacket([]wire.Frame{f})
|
||||
_, err := unpackInitialPacket(aead, hdr, p, ver)
|
||||
Expect(err).To(MatchError("received stream data with non-zero offset"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("packing", func() {
|
||||
var unpacker *packetUnpacker
|
||||
|
||||
BeforeEach(func() {
|
||||
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
unpacker = &packetUnpacker{aead: &nullAEAD{aeadCl}, version: ver}
|
||||
})
|
||||
|
||||
It("packs a packet", func() {
|
||||
f := &wire.StreamFrame{
|
||||
Data: []byte("foobar"),
|
||||
FinBit: true,
|
||||
}
|
||||
data, err := packUnencryptedPacket(aead, hdr, f, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
packet, err := unpacker.Unpack(hdr.Raw, hdr, data[len(hdr.Raw):])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(packet.frames).To(Equal([]wire.Frame{f}))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -17,9 +17,9 @@ type packetNumberGenerator struct {
|
|||
nextToSkip protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
||||
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
|
||||
return &packetNumberGenerator{
|
||||
next: 1,
|
||||
next: initial,
|
||||
averagePeriod: averagePeriod,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,7 +12,12 @@ var _ = Describe("Packet Number Generator", func() {
|
|||
var png packetNumberGenerator
|
||||
|
||||
BeforeEach(func() {
|
||||
png = *newPacketNumberGenerator(100)
|
||||
png = *newPacketNumberGenerator(1, 100)
|
||||
})
|
||||
|
||||
It("can be initialized to return any first packet number", func() {
|
||||
png = *newPacketNumberGenerator(12345, 100)
|
||||
Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345)))
|
||||
})
|
||||
|
||||
It("gets 1 as the first packet number", func() {
|
||||
|
|
|
@ -32,9 +32,11 @@ type packetPacker struct {
|
|||
ackFrame *wire.AckFrame
|
||||
leastUnacked protocol.PacketNumber
|
||||
omitConnectionID bool
|
||||
hasSentPacket bool // has the packetPacker already sent a packet
|
||||
}
|
||||
|
||||
func newPacketPacker(connectionID protocol.ConnectionID,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
cryptoSetup handshake.CryptoSetup,
|
||||
streamFramer *streamFramer,
|
||||
perspective protocol.Perspective,
|
||||
|
@ -46,7 +48,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
|
|||
perspective: perspective,
|
||||
version: version,
|
||||
streamFramer: streamFramer,
|
||||
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
|
||||
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,7 +118,12 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
|
|||
// PackPacket packs a new packet
|
||||
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
|
||||
func (p *packetPacker) PackPacket() (*packedPacket, error) {
|
||||
if p.streamFramer.HasCryptoStreamFrame() {
|
||||
hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame()
|
||||
// if this is the first packet to be send, make sure it contains stream data
|
||||
if !p.hasSentPacket && !hasCryptoStreamFrame {
|
||||
return nil, nil
|
||||
}
|
||||
if hasCryptoStreamFrame {
|
||||
return p.packCryptoPacket()
|
||||
}
|
||||
|
||||
|
@ -266,18 +273,21 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
|||
pnum := p.packetNumberGenerator.Peek()
|
||||
packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked)
|
||||
|
||||
var isLongHeader bool
|
||||
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
|
||||
// TODO: set the Long Header type
|
||||
packetNumberLen = protocol.PacketNumberLen4
|
||||
isLongHeader = true
|
||||
}
|
||||
|
||||
header := &wire.Header{
|
||||
ConnectionID: p.connectionID,
|
||||
PacketNumber: pnum,
|
||||
PacketNumberLen: packetNumberLen,
|
||||
IsLongHeader: isLongHeader,
|
||||
}
|
||||
|
||||
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
|
||||
header.PacketNumberLen = protocol.PacketNumberLen4
|
||||
header.IsLongHeader = true
|
||||
if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
|
||||
header.Type = protocol.PacketTypeInitial
|
||||
// TODO(#886): add padding
|
||||
} else {
|
||||
header.Type = protocol.PacketTypeHandshake
|
||||
}
|
||||
}
|
||||
|
||||
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
|
||||
|
@ -292,7 +302,6 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
|
|||
header.Version = p.version
|
||||
}
|
||||
} else {
|
||||
header.Type = p.cryptoSetup.GetNextPacketType()
|
||||
if encLevel != protocol.EncryptionForwardSecure {
|
||||
header.Version = p.version
|
||||
}
|
||||
|
@ -330,7 +339,7 @@ func (p *packetPacker) writeAndSealPacket(
|
|||
if num != header.PacketNumber {
|
||||
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
|
||||
}
|
||||
|
||||
p.hasSentPacket = true
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@ type mockCryptoSetup struct {
|
|||
divNonce []byte
|
||||
encLevelSeal protocol.EncryptionLevel
|
||||
encLevelSealCrypto protocol.EncryptionLevel
|
||||
nextPacketType protocol.PacketType
|
||||
}
|
||||
|
||||
var _ handshake.CryptoSetup = &mockCryptoSetup{}
|
||||
|
@ -50,7 +49,6 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
|
|||
}
|
||||
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
|
||||
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
|
||||
func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType }
|
||||
|
||||
var _ = Describe("Packet packer", func() {
|
||||
var (
|
||||
|
@ -69,13 +67,14 @@ var _ = Describe("Packet packer", func() {
|
|||
packer = &packetPacker{
|
||||
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
|
||||
connectionID: 0x1337,
|
||||
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
|
||||
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
|
||||
streamFramer: streamFramer,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
}
|
||||
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
|
||||
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
|
||||
packer.version = protocol.VersionWhatever
|
||||
packer.hasSentPacket = true
|
||||
})
|
||||
|
||||
It("returns nil when no packet is queued", func() {
|
||||
|
@ -191,13 +190,6 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(h.Version).To(Equal(versionIETFHeader))
|
||||
})
|
||||
|
||||
It("sets the packet type based on the state of the handshake", func() {
|
||||
packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5
|
||||
h := packer.getHeader(protocol.EncryptionSecure)
|
||||
Expect(h.IsLongHeader).To(BeTrue())
|
||||
Expect(h.Type).To(Equal(protocol.PacketType(5)))
|
||||
})
|
||||
|
||||
It("uses the Short Header format for forward-secure packets", func() {
|
||||
h := packer.getHeader(protocol.EncryptionForwardSecure)
|
||||
Expect(h.IsLongHeader).To(BeFalse())
|
||||
|
@ -269,7 +261,7 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber))
|
||||
})
|
||||
|
||||
It("packs a StopWaitingFrame first", func() {
|
||||
It("packs a STOP_WAITING frame first", func() {
|
||||
packer.packetNumberGenerator.next = 15
|
||||
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
||||
packer.QueueControlFrame(&wire.RstStreamFrame{})
|
||||
|
@ -281,7 +273,7 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(p.frames[0]).To(Equal(swf))
|
||||
})
|
||||
|
||||
It("sets the LeastUnackedDelta length of a StopWaitingFrame", func() {
|
||||
It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
|
||||
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
|
||||
packer.packetNumberGenerator.next = packetNumber
|
||||
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
|
||||
|
@ -292,7 +284,7 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
|
||||
})
|
||||
|
||||
It("does not pack a packet containing only a StopWaitingFrame", func() {
|
||||
It("does not pack a packet containing only a STOP_WAITING frame", func() {
|
||||
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
|
||||
packer.QueueControlFrame(swf)
|
||||
p, err := packer.PackPacket()
|
||||
|
@ -307,6 +299,14 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(p).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
|
||||
packer.hasSentPacket = false
|
||||
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
|
||||
p, err := packer.PackPacket()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(p).To(BeNil())
|
||||
})
|
||||
|
||||
It("packs many control frames into 1 packets", func() {
|
||||
f := &wire.AckFrame{LargestAcked: 1}
|
||||
b := &bytes.Buffer{}
|
||||
|
@ -602,7 +602,7 @@ var _ = Describe("Packet packer", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("Blocked frames", func() {
|
||||
Context("BLOCKED frames", func() {
|
||||
It("queues a BLOCKED frame", func() {
|
||||
length := 100
|
||||
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
|
||||
|
@ -750,7 +750,7 @@ var _ = Describe("Packet packer", func() {
|
|||
Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment"))
|
||||
})
|
||||
|
||||
It("refuses to retransmit packets without a StopWaitingFrame", func() {
|
||||
It("refuses to retransmit packets without a STOP_WAITING Frame", func() {
|
||||
packer.stopWaiting = nil
|
||||
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
|
||||
EncryptionLevel: protocol.EncryptionSecure,
|
||||
|
|
86
server.go
86
server.go
|
@ -19,6 +19,7 @@ import (
|
|||
// packetHandler handles packets
|
||||
type packetHandler interface {
|
||||
Session
|
||||
getCryptoStream() cryptoStream
|
||||
handshakeStatus() <-chan handshakeEvent
|
||||
handlePacket(*receivedPacket)
|
||||
GetVersion() protocol.VersionNumber
|
||||
|
@ -33,6 +34,9 @@ type server struct {
|
|||
|
||||
conn net.PacketConn
|
||||
|
||||
supportsTLS bool
|
||||
serverTLS *serverTLS
|
||||
|
||||
certChain crypto.CertChain
|
||||
scfg *handshake.ServerConfig
|
||||
|
||||
|
@ -77,11 +81,21 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config = populateServerConfig(config)
|
||||
|
||||
// check if any of the supported versions supports TLS
|
||||
var supportsTLS bool
|
||||
for _, v := range config.Versions {
|
||||
if v.UsesTLS() {
|
||||
supportsTLS = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s := &server{
|
||||
conn: conn,
|
||||
tlsConf: tlsConf,
|
||||
config: populateServerConfig(config),
|
||||
config: config,
|
||||
certChain: certChain,
|
||||
scfg: scfg,
|
||||
sessions: map[protocol.ConnectionID]packetHandler{},
|
||||
|
@ -89,12 +103,47 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
|
|||
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
|
||||
sessionQueue: make(chan Session, 5),
|
||||
errorChan: make(chan struct{}),
|
||||
supportsTLS: supportsTLS,
|
||||
}
|
||||
if supportsTLS {
|
||||
if err := s.setupTLS(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
go s.serve()
|
||||
utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *server) setupTLS() error {
|
||||
cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.serverTLS = serverTLS
|
||||
// handle TLS connection establishment statelessly
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.errorChan:
|
||||
return
|
||||
case sess := <-sessionChan:
|
||||
// TODO: think about what to do with connection ID collisions
|
||||
connID := sess.(*session).connectionID
|
||||
s.sessionsMutex.Lock()
|
||||
s.sessions[connID] = sess
|
||||
s.sessionsMutex.Unlock()
|
||||
s.runHandshakeAndSession(sess, connID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
|
||||
if cookie == nil {
|
||||
return false
|
||||
|
@ -225,8 +274,16 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
|
||||
}
|
||||
hdr.Raw = packet[:len(packet)-r.Len()]
|
||||
packetData := packet[len(packet)-r.Len():]
|
||||
connID := hdr.ConnectionID
|
||||
|
||||
if hdr.Type == protocol.PacketTypeInitial {
|
||||
if s.supportsTLS {
|
||||
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.sessionsMutex.RLock()
|
||||
session, sessionKnown := s.sessions[connID]
|
||||
s.sessionsMutex.RUnlock()
|
||||
|
@ -279,11 +336,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
return err
|
||||
}
|
||||
}
|
||||
// send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header
|
||||
if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
|
||||
_, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.config.Versions), remoteAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
if !sessionKnown {
|
||||
version := hdr.Version
|
||||
|
@ -307,9 +359,21 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
s.sessions[connID] = session
|
||||
s.sessionsMutex.Unlock()
|
||||
|
||||
s.runHandshakeAndSession(session, connID)
|
||||
}
|
||||
session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packetData,
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
|
||||
go func() {
|
||||
// session.run() returns as soon as the session is closed
|
||||
_ = session.run()
|
||||
// session.run() returns as soon as the session is closed
|
||||
s.removeConnection(connID)
|
||||
}()
|
||||
|
||||
|
@ -326,14 +390,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
|||
s.sessionQueue <- session
|
||||
}()
|
||||
}
|
||||
session.handlePacket(&receivedPacket{
|
||||
remoteAddr: remoteAddr,
|
||||
header: hdr,
|
||||
data: packet[len(packet)-r.Len():],
|
||||
rcvTime: rcvTime,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) removeConnection(id protocol.ConnectionID) {
|
||||
s.sessionsMutex.Lock()
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
@ -67,6 +68,7 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not imple
|
|||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
|
||||
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
|
||||
func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") }
|
||||
|
||||
var _ Session = &mockSession{}
|
||||
var _ NonFWSession = &mockSession{}
|
||||
|
@ -96,7 +98,8 @@ var _ = Describe("Server", func() {
|
|||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
conn = &mockPacketConn{addr: &net.UDPAddr{}}
|
||||
conn = newMockPacketConn()
|
||||
conn.addr = &net.UDPAddr{}
|
||||
config = &Config{Versions: protocol.SupportedVersions}
|
||||
})
|
||||
|
||||
|
@ -235,14 +238,14 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("works if no quic.Config is given", func(done Done) {
|
||||
ln, err := ListenAddr("127.0.0.1:0", nil, config)
|
||||
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ln.Close()).To(Succeed())
|
||||
close(done)
|
||||
}, 1)
|
||||
|
||||
It("closes properly", func() {
|
||||
ln, err := ListenAddr("127.0.0.1:0", nil, config)
|
||||
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var returned bool
|
||||
|
@ -409,7 +412,7 @@ var _ = Describe("Server", func() {
|
|||
}
|
||||
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
||||
conn.dataToRead = b.Bytes()
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -432,7 +435,7 @@ var _ = Describe("Server", func() {
|
|||
})
|
||||
|
||||
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
|
||||
config.Versions = []protocol.VersionNumber{99}
|
||||
config.Versions = []protocol.VersionNumber{99, protocol.VersionTLS}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
|
@ -441,11 +444,12 @@ var _ = Describe("Server", func() {
|
|||
PacketNumber: 0x55,
|
||||
Version: 0x1234,
|
||||
}
|
||||
hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
||||
conn.dataToRead = b.Bytes()
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, nil, config)
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
@ -466,9 +470,32 @@ var _ = Describe("Server", func() {
|
|||
Consistently(done).ShouldNot(BeClosed())
|
||||
})
|
||||
|
||||
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
|
||||
version := protocol.VersionNumber(99)
|
||||
Expect(version.UsesTLS()).To(BeFalse())
|
||||
config.Versions = []protocol.VersionNumber{version}
|
||||
b := &bytes.Buffer{}
|
||||
hdr := wire.Header{
|
||||
Type: protocol.PacketTypeInitial,
|
||||
IsLongHeader: true,
|
||||
ConnectionID: 0x1337,
|
||||
PacketNumber: 0x55,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
|
||||
conn.dataToRead <- b.Bytes()
|
||||
conn.dataReadFrom = udpAddr
|
||||
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
|
||||
defer ln.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
|
||||
})
|
||||
|
||||
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
|
||||
conn.dataReadFrom = udpAddr
|
||||
conn.dataToRead = []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
|
||||
conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
|
||||
ln, err := Listen(conn, nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go func() {
|
||||
|
|
179
server_tls.go
Normal file
179
server_tls.go
Normal file
|
@ -0,0 +1,179 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type nullAEAD struct {
|
||||
aead crypto.AEAD
|
||||
}
|
||||
|
||||
var _ quicAEAD = &nullAEAD{}
|
||||
|
||||
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
data, err := n.aead.Open(dst, src, packetNumber, associatedData)
|
||||
return data, protocol.EncryptionUnencrypted, err
|
||||
}
|
||||
|
||||
type serverTLS struct {
|
||||
conn net.PacketConn
|
||||
config *Config
|
||||
supportedVersions []protocol.VersionNumber
|
||||
mintConf *mint.Config
|
||||
cookieProtector mint.CookieProtector
|
||||
params *handshake.TransportParameters
|
||||
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
|
||||
|
||||
sessionChan chan<- packetHandler
|
||||
}
|
||||
|
||||
func newServerTLS(
|
||||
conn net.PacketConn,
|
||||
config *Config,
|
||||
cookieHandler *handshake.CookieHandler,
|
||||
tlsConf *tls.Config,
|
||||
) (*serverTLS, <-chan packetHandler, error) {
|
||||
mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
mconf.RequireCookie = true
|
||||
cs, err := mint.NewDefaultCookieProtector()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
mconf.CookieProtector = cs
|
||||
mconf.CookieHandler = cookieHandler
|
||||
|
||||
sessionChan := make(chan packetHandler)
|
||||
s := &serverTLS{
|
||||
conn: conn,
|
||||
config: config,
|
||||
supportedVersions: config.Versions,
|
||||
mintConf: mconf,
|
||||
sessionChan: sessionChan,
|
||||
params: &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
MaxStreams: protocol.MaxIncomingStreams,
|
||||
IdleTimeout: config.IdleTimeout,
|
||||
},
|
||||
}
|
||||
s.newMintConn = s.newMintConnImpl
|
||||
return s, sessionChan, nil
|
||||
}
|
||||
|
||||
func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
|
||||
utils.Debugf("Received a Packet. Handling it statelessly.")
|
||||
sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
|
||||
if err != nil {
|
||||
utils.Errorf("Error occured handling initial packet: %s", err)
|
||||
return
|
||||
}
|
||||
if sess == nil { // a stateless reset was done
|
||||
return
|
||||
}
|
||||
s.sessionChan <- sess
|
||||
}
|
||||
|
||||
// will be set to s.newMintConn by the constructor
|
||||
func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
|
||||
conn := mint.Server(bc, s.mintConf)
|
||||
extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v)
|
||||
if err := conn.SetExtensionHandler(extHandler); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tls := newMintController(bc, s.mintConf, protocol.PerspectiveServer)
|
||||
tls.SetExtensionHandler(extHandler)
|
||||
return tls, extHandler.GetPeerParams(), nil
|
||||
}
|
||||
|
||||
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
|
||||
// TODO: check length requirement
|
||||
// check version, if not matching send VNP
|
||||
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
|
||||
utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
|
||||
_, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.supportedVersions), remoteAddr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// unpack packet and check stream frame contents
|
||||
version := hdr.Version
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame, err := unpackInitialPacket(aead, hdr, data, version)
|
||||
if err != nil {
|
||||
utils.Debugf("Error unpacking initial packet: %s", err)
|
||||
return nil, nil
|
||||
}
|
||||
bc := handshake.NewCryptoStreamConn(remoteAddr)
|
||||
bc.AddDataForReading(frame.Data)
|
||||
tls, paramsChan, err := s.newMintConn(bc, hdr.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
alert := tls.Handshake()
|
||||
if alert == mint.AlertStatelessRetry {
|
||||
// the HelloRetryRequest was written to the bufferConn
|
||||
// Take that data and write send a Retry packet
|
||||
replyHdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
ConnectionID: hdr.ConnectionID, // echo the client's connection ID
|
||||
PacketNumber: hdr.PacketNumber, // echo the client's packet number
|
||||
Version: version,
|
||||
}
|
||||
f := &wire.StreamFrame{
|
||||
StreamID: version.CryptoStreamID(),
|
||||
Data: bc.GetDataForWriting(),
|
||||
}
|
||||
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = s.conn.WriteTo(data, remoteAddr)
|
||||
return nil, err
|
||||
}
|
||||
if alert != mint.AlertNoAlert {
|
||||
return nil, alert
|
||||
}
|
||||
if tls.State() != mint.StateServerNegotiated {
|
||||
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
|
||||
}
|
||||
if alert := tls.Handshake(); alert != mint.AlertNoAlert {
|
||||
return nil, alert
|
||||
}
|
||||
if tls.State() != mint.StateServerWaitFlight2 {
|
||||
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
|
||||
}
|
||||
params := <-paramsChan
|
||||
sess, err := newTLSServerSession(
|
||||
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||
hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here
|
||||
protocol.PacketNumber(1), // TODO: use a random packet number here
|
||||
s.config,
|
||||
tls,
|
||||
bc,
|
||||
aead,
|
||||
¶ms,
|
||||
version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs := sess.getCryptoStream()
|
||||
cs.SetReadOffset(frame.DataLen())
|
||||
bc.SetStream(cs)
|
||||
return sess, nil
|
||||
}
|
116
server_tls_test.go
Normal file
116
server_tls_test.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
|
||||
|
||||
"github.com/bifurcation/mint"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Stateless TLS handling", func() {
|
||||
var (
|
||||
conn *mockPacketConn
|
||||
server *serverTLS
|
||||
sessionChan <-chan packetHandler
|
||||
mintTLS *mockhandshake.MockMintTLS
|
||||
extHandler *mocks.MockTLSExtensionHandler
|
||||
mintReply io.Writer
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
mintTLS = mockhandshake.NewMockMintTLS(mockCtrl)
|
||||
extHandler = mocks.NewMockTLSExtensionHandler(mockCtrl)
|
||||
conn = newMockPacketConn()
|
||||
config := &Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
var err error
|
||||
server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
|
||||
mintReply = bc
|
||||
return mintTLS, extHandler.GetPeerParams(), nil
|
||||
}
|
||||
})
|
||||
|
||||
getPacket := func(f wire.Frame) (*wire.Header, []byte) {
|
||||
hdrBuf := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
PacketNumber: 1,
|
||||
Version: protocol.VersionTLS,
|
||||
}
|
||||
err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr.Raw = hdrBuf.Bytes()
|
||||
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, 0, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
buf := &bytes.Buffer{}
|
||||
err = f.Write(buf, protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return hdr, aead.Seal(nil, buf.Bytes(), 1, hdr.Raw)
|
||||
}
|
||||
|
||||
It("sends a version negotiation packet if it doesn't support the version", func() {
|
||||
server.HandleInitial(nil, &wire.Header{Version: 0x1337}, nil)
|
||||
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
||||
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionUnknown)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.IsVersionNegotiation).To(BeTrue())
|
||||
Expect(sessionChan).ToNot(Receive())
|
||||
})
|
||||
|
||||
It("ignores packets with invalid contents", func() {
|
||||
hdr, data := getPacket(&wire.StreamFrame{StreamID: 10, Offset: 11, Data: []byte("foobar")})
|
||||
server.HandleInitial(nil, hdr, data)
|
||||
Expect(conn.dataWritten.Len()).To(BeZero())
|
||||
Expect(sessionChan).ToNot(Receive())
|
||||
})
|
||||
|
||||
It("replies with a Retry packet, if a Cookie is required", func() {
|
||||
extHandler.EXPECT().GetPeerParams()
|
||||
mintTLS.EXPECT().Handshake().Return(mint.AlertStatelessRetry).Do(func() {
|
||||
mintReply.Write([]byte("Retry with this Cookie"))
|
||||
})
|
||||
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
|
||||
server.HandleInitial(nil, hdr, data)
|
||||
Expect(conn.dataWritten.Len()).ToNot(BeZero())
|
||||
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionTLS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
|
||||
Expect(sessionChan).ToNot(Receive())
|
||||
})
|
||||
|
||||
It("replies with a Handshake packet and creates a session, if no Cookie is required", func() {
|
||||
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert).Do(func() {
|
||||
mintReply.Write([]byte("Server Hello"))
|
||||
})
|
||||
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert)
|
||||
mintTLS.EXPECT().State().Return(mint.StateServerNegotiated)
|
||||
mintTLS.EXPECT().State().Return(mint.StateServerWaitFlight2)
|
||||
paramsChan := make(chan handshake.TransportParameters, 1)
|
||||
paramsChan <- handshake.TransportParameters{}
|
||||
extHandler.EXPECT().GetPeerParams().Return(paramsChan)
|
||||
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
server.HandleInitial(nil, hdr, data)
|
||||
// the Handshake packet is written by the session
|
||||
Expect(conn.dataWritten.Len()).To(BeZero())
|
||||
close(done)
|
||||
}()
|
||||
Eventually(sessionChan).Should(Receive())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
248
session.go
248
session.go
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
|
||||
"github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
@ -60,7 +61,7 @@ type session struct {
|
|||
conn connection
|
||||
|
||||
streamsMap *streamsMap
|
||||
cryptoStream streamI
|
||||
cryptoStream cryptoStream
|
||||
|
||||
rttStats *congestion.RTTStats
|
||||
|
||||
|
@ -126,21 +127,48 @@ func newSession(
|
|||
conn connection,
|
||||
v protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
sCfg *handshake.ServerConfig,
|
||||
scfg *handshake.ServerConfig,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (packetHandler, error) {
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
config: config,
|
||||
aeadChanged: aeadChanged,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
return s, s.setup(sCfg, "", tlsConf, v, nil)
|
||||
s.preSetup()
|
||||
transportParams := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
MaxStreams: protocol.MaxIncomingStreams,
|
||||
IdleTimeout: s.config.IdleTimeout,
|
||||
}
|
||||
cs, err := newCryptoSetup(
|
||||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
s.conn.RemoteAddr(),
|
||||
s.version,
|
||||
scfg,
|
||||
transportParams,
|
||||
s.config.Versions,
|
||||
s.config.AcceptCookie,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cryptoSetup = cs
|
||||
return s, s.postSetup(1)
|
||||
}
|
||||
|
||||
// declare this as a variable, such that we can it mock it in the tests
|
||||
// declare this as a variable, so that we can it mock it in the tests
|
||||
var newClientSession = func(
|
||||
conn connection,
|
||||
hostname string,
|
||||
|
@ -151,27 +179,130 @@ var newClientSession = func(
|
|||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
|
||||
) (packetHandler, error) {
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
config: config,
|
||||
aeadChanged: aeadChanged,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
return s, s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions)
|
||||
s.preSetup()
|
||||
transportParams := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
MaxStreams: protocol.MaxIncomingStreams,
|
||||
IdleTimeout: s.config.IdleTimeout,
|
||||
OmitConnectionID: s.config.RequestConnectionIDOmission,
|
||||
}
|
||||
cs, err := newCryptoSetupClient(
|
||||
s.cryptoStream,
|
||||
hostname,
|
||||
s.connectionID,
|
||||
s.version,
|
||||
tlsConf,
|
||||
transportParams,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
initialVersion,
|
||||
negotiatedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cryptoSetup = cs
|
||||
return s, s.postSetup(1)
|
||||
}
|
||||
|
||||
func (s *session) setup(
|
||||
scfg *handshake.ServerConfig,
|
||||
hostname string,
|
||||
tlsConf *tls.Config,
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
) error {
|
||||
func newTLSServerSession(
|
||||
conn connection,
|
||||
connectionID protocol.ConnectionID,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
config *Config,
|
||||
tls handshake.MintTLS,
|
||||
cryptoStreamConn *handshake.CryptoStreamConn,
|
||||
nullAEAD crypto.AEAD,
|
||||
peerParams *handshake.TransportParameters,
|
||||
v protocol.VersionNumber,
|
||||
) (packetHandler, error) {
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
paramsChan := make(chan handshake.TransportParameters)
|
||||
s.aeadChanged = aeadChanged
|
||||
s.paramsChan = paramsChan
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
version: v,
|
||||
aeadChanged: aeadChanged,
|
||||
}
|
||||
s.preSetup()
|
||||
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
|
||||
tls,
|
||||
cryptoStreamConn,
|
||||
nullAEAD,
|
||||
aeadChanged,
|
||||
v,
|
||||
)
|
||||
if err := s.postSetup(initialPacketNumber); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.peerParams = peerParams
|
||||
s.processTransportParameters(peerParams)
|
||||
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// declare this as a variable, such that we can it mock it in the tests
|
||||
var newTLSClientSession = func(
|
||||
conn connection,
|
||||
hostname string,
|
||||
v protocol.VersionNumber,
|
||||
connectionID protocol.ConnectionID,
|
||||
config *Config,
|
||||
tls handshake.MintTLS,
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
) (packetHandler, error) {
|
||||
aeadChanged := make(chan protocol.EncryptionLevel, 2)
|
||||
s := &session{
|
||||
conn: conn,
|
||||
config: config,
|
||||
connectionID: connectionID,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
version: v,
|
||||
aeadChanged: aeadChanged,
|
||||
paramsChan: paramsChan,
|
||||
}
|
||||
s.preSetup()
|
||||
tls.SetCryptoStream(s.cryptoStream)
|
||||
cs, err := handshake.NewCryptoSetupTLSClient(
|
||||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
hostname,
|
||||
aeadChanged,
|
||||
tls,
|
||||
v,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cryptoSetup = cs
|
||||
return s, s.postSetup(initialPacketNumber)
|
||||
}
|
||||
|
||||
func (s *session) preSetup() {
|
||||
s.rttStats = &congestion.RTTStats{}
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ReceiveConnectionFlowControlWindow,
|
||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
s.rttStats,
|
||||
)
|
||||
s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream)
|
||||
}
|
||||
|
||||
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
|
||||
s.handshakeChan = make(chan handshakeEvent, 3)
|
||||
s.handshakeCompleteChan = make(chan error, 1)
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
|
@ -185,91 +316,14 @@ func (s *session) setup(
|
|||
s.lastNetworkActivityTime = now
|
||||
s.sessionCreationTime = now
|
||||
|
||||
s.rttStats = &congestion.RTTStats{}
|
||||
transportParams := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
MaxStreams: protocol.MaxIncomingStreams,
|
||||
IdleTimeout: s.config.IdleTimeout,
|
||||
}
|
||||
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
|
||||
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
|
||||
s.connFlowController = flowcontrol.NewConnectionFlowController(
|
||||
protocol.ReceiveConnectionFlowControlWindow,
|
||||
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
|
||||
s.rttStats,
|
||||
)
|
||||
|
||||
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
|
||||
s.cryptoStream = s.newStream(s.version.CryptoStreamID())
|
||||
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
|
||||
|
||||
var err error
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool {
|
||||
return s.config.AcceptCookie(clientAddr, cookie)
|
||||
}
|
||||
if s.version.UsesTLS() {
|
||||
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
|
||||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
tlsConf,
|
||||
s.conn.RemoteAddr(),
|
||||
transportParams,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
verifySourceAddr,
|
||||
s.config.Versions,
|
||||
s.version,
|
||||
)
|
||||
} else {
|
||||
s.cryptoSetup, err = newCryptoSetup(
|
||||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
s.conn.RemoteAddr(),
|
||||
s.version,
|
||||
scfg,
|
||||
transportParams,
|
||||
s.config.Versions,
|
||||
verifySourceAddr,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
|
||||
if s.version.UsesTLS() {
|
||||
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
|
||||
s.cryptoStream,
|
||||
s.connectionID,
|
||||
hostname,
|
||||
tlsConf,
|
||||
transportParams,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
initialVersion,
|
||||
s.config.Versions,
|
||||
s.version,
|
||||
)
|
||||
} else {
|
||||
s.cryptoSetup, err = newCryptoSetupClient(
|
||||
s.cryptoStream,
|
||||
hostname,
|
||||
s.connectionID,
|
||||
s.version,
|
||||
tlsConf,
|
||||
transportParams,
|
||||
paramsChan,
|
||||
aeadChanged,
|
||||
initialVersion,
|
||||
negotiatedVersions,
|
||||
)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.packer = newPacketPacker(s.connectionID,
|
||||
initialPacketNumber,
|
||||
s.cryptoSetup,
|
||||
s.streamFramer,
|
||||
s.perspective,
|
||||
|
@ -604,7 +658,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
|
|||
s.cryptoStream.Cancel(quicErr)
|
||||
s.streamsMap.CloseWithError(quicErr)
|
||||
|
||||
if closeErr.err == errCloseSessionForNewVersion {
|
||||
if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -893,6 +947,10 @@ func (s *session) handshakeStatus() <-chan handshakeEvent {
|
|||
return s.handshakeChan
|
||||
}
|
||||
|
||||
func (s *session) getCryptoStream() cryptoStream {
|
||||
return s.cryptoStream
|
||||
}
|
||||
|
||||
func (s *session) GetVersion() protocol.VersionNumber {
|
||||
return s.version
|
||||
}
|
||||
|
|
|
@ -468,9 +468,6 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
|
||||
It("handles CONNECTION_CLOSE frames", func() {
|
||||
cryptoStream := mocks.NewMockStreamI(mockCtrl)
|
||||
cryptoStream.EXPECT().Cancel(gomock.Any())
|
||||
sess.cryptoStream = cryptoStream
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
|
@ -771,10 +768,15 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
|
||||
Context("sending packets", func() {
|
||||
BeforeEach(func() {
|
||||
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||
})
|
||||
|
||||
It("sends ACK frames", func() {
|
||||
packetNumber := protocol.PacketNumber(0x035e)
|
||||
sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
|
||||
err := sess.sendPacket()
|
||||
err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = sess.sendPacket()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e}))))
|
||||
|
@ -858,6 +860,7 @@ var _ = Describe("Session", func() {
|
|||
BeforeEach(func() {
|
||||
// a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet
|
||||
sess.packer.packetNumberGenerator.next = 0x1337 + 10
|
||||
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||
sph = newMockSentPacketHandler().(*mockSentPacketHandler)
|
||||
sess.sentPacketHandler = sph
|
||||
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
||||
|
@ -981,6 +984,7 @@ var _ = Describe("Session", func() {
|
|||
})
|
||||
|
||||
It("retransmits RTO packets", func() {
|
||||
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||
sess.sentPacketHandler.SetHandshakeComplete()
|
||||
n := protocol.PacketNumber(10)
|
||||
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
|
||||
|
@ -1008,6 +1012,7 @@ var _ = Describe("Session", func() {
|
|||
|
||||
Context("scheduling sending", func() {
|
||||
BeforeEach(func() {
|
||||
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||
sess.processTransportParameters(&handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.MaxByteCount,
|
||||
ConnectionFlowControlWindow: protocol.MaxByteCount,
|
||||
|
@ -1291,6 +1296,7 @@ var _ = Describe("Session", func() {
|
|||
sess.handshakeComplete = true
|
||||
sess.config.KeepAlive = true
|
||||
sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2)
|
||||
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
|
||||
go sess.run()
|
||||
defer sess.Close(nil)
|
||||
var data []byte
|
||||
|
@ -1551,7 +1557,10 @@ var _ = Describe("Client Session", func() {
|
|||
})
|
||||
|
||||
It("passes the diversification nonce to the cryptoSetup", func() {
|
||||
go sess.run()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
sess.run()
|
||||
}()
|
||||
hdr.PacketNumber = 5
|
||||
hdr.DiversificationNonce = []byte("foobar")
|
||||
err := sess.handlePacketImpl(&receivedPacket{header: hdr})
|
||||
|
|
13
stream.go
13
stream.go
|
@ -32,6 +32,11 @@ type streamI interface {
|
|||
IsFlowControlBlocked() bool
|
||||
}
|
||||
|
||||
type cryptoStream interface {
|
||||
streamI
|
||||
SetReadOffset(protocol.ByteCount)
|
||||
}
|
||||
|
||||
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
|
||||
//
|
||||
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
|
||||
|
@ -481,3 +486,11 @@ func (s *stream) IsFlowControlBlocked() bool {
|
|||
func (s *stream) GetWindowUpdate() protocol.ByteCount {
|
||||
return s.flowController.GetWindowUpdate()
|
||||
}
|
||||
|
||||
// SetReadOffset sets the read offset.
|
||||
// It is only needed for the crypto stream.
|
||||
// It must not be called concurrently with any other stream methods, especially Read and Write.
|
||||
func (s *stream) SetReadOffset(offset protocol.ByteCount) {
|
||||
s.readOffset = offset
|
||||
s.frameQueue.readPosition = offset
|
||||
}
|
||||
|
|
|
@ -266,6 +266,12 @@ var _ = Describe("Stream", func() {
|
|||
Expect(onDataCalled).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the read offset", func() {
|
||||
str.SetReadOffset(0x42)
|
||||
Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42)))
|
||||
Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
|
||||
})
|
||||
|
||||
Context("deadlines", func() {
|
||||
It("the deadline error has the right net.Error properties", func() {
|
||||
Expect(errDeadline.Temporary()).To(BeTrue())
|
||||
|
|
|
@ -325,4 +325,5 @@ func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
|
|||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.maxOutgoingStreams = limit
|
||||
m.openStreamOrErrCond.Broadcast()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue