implement stateless handling of Initial packets for the TLS server

This commit is contained in:
Marten Seemann 2017-11-11 08:38:27 +08:00
parent 57c6f3ceb5
commit 25a6dc9654
36 changed files with 1617 additions and 724 deletions

156
client.go
View file

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
@ -23,14 +24,18 @@ type client struct {
hostname string hostname string
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has version negotiation completed yet versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config tlsConf *tls.Config
config *Config config *Config
tls handshake.MintTLS // only used when using TLS
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
version protocol.VersionNumber
session packetHandler session packetHandler
} }
@ -91,7 +96,6 @@ func DialNonFWSecure(
if tlsConf != nil { if tlsConf != nil {
hostname = tlsConf.ServerName hostname = tlsConf.ServerName
} }
if hostname == "" { if hostname == "" {
hostname, _, err = net.SplitHostPort(host) hostname, _, err = net.SplitHostPort(host)
if err != nil { if err != nil {
@ -111,8 +115,9 @@ func DialNonFWSecure(
} }
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version) utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
go c.listen()
if err := c.establishSecureConnection(); err != nil { if err := c.dial(); err != nil {
return nil, err return nil, err
} }
return c.session.(NonFWSession), nil return c.session.(NonFWSession), nil
@ -177,25 +182,79 @@ func populateClientConfig(config *Config) *Config {
} }
} }
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure) func (c *client) dial() error {
func (c *client) establishSecureConnection() error { var err error
if err := c.createNewSession(c.version, nil); err != nil { if c.version.UsesTLS() {
err = c.dialTLS()
} else {
err = c.dialGQUIC()
}
if err == errCloseSessionForNewVersion {
return c.dial()
}
return err
}
func (c *client) dialGQUIC() error {
if err := c.createNewGQUICSession(); err != nil {
return err return err
} }
go c.listen() return c.establishSecureConnection()
}
func (c *client) dialTLS() error {
csc := handshake.NewCryptoStreamConn(nil)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
}
eh := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
if err := c.tls.SetExtensionHandler(eh); err != nil {
return err
}
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry {
return err
}
utils.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
return err
}
}
return nil
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error {
var runErr error var runErr error
errorChan := make(chan struct{}) errorChan := make(chan struct{})
go func() { go func() {
// session.run() returns as soon as the session is closed runErr = c.session.run() // returns as soon as the session is closed
runErr = c.session.run()
if runErr == errCloseSessionForNewVersion {
// run the new session
runErr = c.session.run()
}
close(errorChan) close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID) utils.Infof("Connection %x closed.", c.connectionID)
c.conn.Close() if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close()
}
}() }()
// wait until the server accepts the QUIC version (or an error occurs) // wait until the server accepts the QUIC version (or an error occurs)
@ -219,7 +278,8 @@ func (c *client) establishSecureConnection() error {
} }
} }
// Listen listens // Listen listens on the underlying connection and passes packets on for handling.
// It returns when the connection is closed.
func (c *client) listen() { func (c *client) listen() {
var err error var err error
@ -233,13 +293,15 @@ func (c *client) listen() {
n, addr, err = c.conn.Read(data) n, addr, err = c.conn.Read(data)
if err != nil { if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") { if !strings.HasSuffix(err.Error(), "use of closed network connection") {
c.session.Close(err) c.mutex.Lock()
if c.session != nil {
c.session.Close(err)
}
c.mutex.Unlock()
} }
break break
} }
data = data[:n] c.handlePacket(addr, data[:n])
c.handlePacket(addr, data)
} }
} }
@ -257,15 +319,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission { if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return return
} }
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
if hdr.ResetFlag { if hdr.ResetFlag {
cr := c.conn.RemoteAddr() cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match // check if the remote address and the connection ID match
@ -305,6 +368,8 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
close(c.versionNegotiationChan) close(c.versionNegotiationChan)
} }
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
header: hdr, header: hdr,
@ -323,15 +388,15 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
} }
} }
c.receivedVersionNegotiationPacket = true
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok { if !ok {
return qerr.InvalidVersion return qerr.InvalidVersion
} }
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version // switch to negotiated version
initialVersion := c.version c.initialVersion = c.version
c.version = newVersion c.version = newVersion
var err error var err error
c.connectionID, err = utils.GenerateConnectionID() c.connectionID, err = utils.GenerateConnectionID()
@ -339,17 +404,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
return err return err
} }
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID) utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
// create a new session and close the old one return nil
// the new session must be created first to update client member variables
oldSession := c.session
defer oldSession.Close(errCloseSessionForNewVersion)
return c.createNewSession(initialVersion, hdr.SupportedVersions)
} }
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error { func (c *client) createNewGQUICSession() (err error) {
var err error c.mutex.Lock()
utils.Debugf("createNewSession with initial version %s", initialVersion) defer c.mutex.Unlock()
c.session, err = newClientSession( c.session, err = newClientSession(
c.conn, c.conn,
c.hostname, c.hostname,
@ -357,8 +418,27 @@ func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotia
c.connectionID, c.connectionID,
c.tlsConf, c.tlsConf,
c.config, c.config,
initialVersion, c.initialVersion,
negotiatedVersions, c.negotiatedVersions,
)
return err
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newTLSClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.config,
c.tls,
paramsChan,
1,
) )
return err return err
} }

View file

@ -8,6 +8,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
@ -45,10 +46,9 @@ var _ = Describe("Client", func() {
msess, _ := newMockSession(nil, 0, 0, nil, nil, nil) msess, _ := newMockSession(nil, 0, 0, nil, nil, nil)
sess = msess.(*mockSession) sess = msess.(*mockSession)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = &mockPacketConn{ packetConn = newMockPacketConn()
addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}, packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
dataReadFrom: addr, packetConn.dataReadFrom = addr
}
config = &Config{ config = &Config{
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78}, Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
} }
@ -87,7 +87,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber, _ protocol.VersionNumber,
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
) (packetHandler, error) { ) (packetHandler, error) {
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed()) Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
return sess, nil return sess, nil
} }
origGenerateConnectionID = generateConnectionID origGenerateConnectionID = generateConnectionID
@ -101,7 +101,7 @@ var _ = Describe("Client", func() {
}) })
It("dials non-forward-secure", func() { It("dials non-forward-secure", func() {
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -151,7 +151,7 @@ var _ = Describe("Client", func() {
}) })
It("Dial only returns after the handshake is complete", func() { It("Dial only returns after the handshake is complete", func() {
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -237,7 +237,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the connection to become secure", func() { It("returns an error that occurs while waiting for the connection to become secure", func() {
testErr := errors.New("early handshake error") testErr := errors.New("early handshake error")
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -251,7 +251,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the handshake to complete", func() { It("returns an error that occurs while waiting for the handshake to complete", func() {
testErr := errors.New("late handshake error") testErr := errors.New("late handshake error")
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID) packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -330,10 +330,6 @@ var _ = Describe("Client", func() {
newVersion := protocol.VersionNumber(77) newVersion := protocol.VersionNumber(77)
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion)) Expect(config.Versions).To(ContainElement(newVersion))
packetConn.dataToRead = wire.ComposeGQUICVersionNegotiation(
cl.connectionID,
[]protocol.VersionNumber{newVersion},
)
sessionChan := make(chan *mockSession) sessionChan := make(chan *mockSession)
handshakeChan := make(chan handshakeEvent) handshakeChan := make(chan handshakeEvent)
newClientSession = func( newClientSession = func(
@ -348,10 +344,7 @@ var _ = Describe("Client", func() {
) (packetHandler, error) { ) (packetHandler, error) {
initialVersion = initialVersionP initialVersion = initialVersionP
negotiatedVersions = negotiatedVersionsP negotiatedVersions = negotiatedVersionsP
// make the server accept the new version
if len(negotiatedVersionsP) > 0 {
packetConn.dataToRead = acceptClientVersionPacket(connectionID)
}
sess := &mockSession{ sess := &mockSession{
connectionID: connectionID, connectionID: connectionID,
stopRunLoop: make(chan struct{}), stopRunLoop: make(chan struct{}),
@ -364,18 +357,26 @@ var _ = Describe("Client", func() {
established := make(chan struct{}) established := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := cl.establishSecureConnection() err := cl.dial()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(established) close(established)
}() }()
go cl.listen()
actualInitialVersion := cl.version actualInitialVersion := cl.version
var firstSession, secondSession *mockSession var firstSession, secondSession *mockSession
Eventually(sessionChan).Should(Receive(&firstSession)) Eventually(sessionChan).Should(Receive(&firstSession))
Eventually(sessionChan).Should(Receive(&secondSession)) packetConn.dataToRead <- wire.ComposeGQUICVersionNegotiation(
cl.connectionID,
[]protocol.VersionNumber{newVersion},
)
// it didn't pass the version negoation packet to the old session (since it has no payload) // it didn't pass the version negoation packet to the old session (since it has no payload)
Expect(firstSession.packetCount).To(BeZero())
Eventually(func() bool { return firstSession.closed }).Should(BeTrue()) Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion)) Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
Expect(firstSession.packetCount).To(BeZero())
Eventually(sessionChan).Should(Receive(&secondSession))
// make the server accept the new version
packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID)
Consistently(func() bool { return secondSession.closed }).Should(BeFalse()) Consistently(func() bool { return secondSession.closed }).Should(BeFalse())
Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337)) Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337))
Expect(negotiatedVersions).To(ContainElement(newVersion)) Expect(negotiatedVersions).To(ContainElement(newVersion))
@ -398,20 +399,23 @@ var _ = Describe("Client", func() {
_ []protocol.VersionNumber, _ []protocol.VersionNumber,
) (packetHandler, error) { ) (packetHandler, error) {
atomic.AddUint32(&sessionCounter, 1) atomic.AddUint32(&sessionCounter, 1)
return sess, nil return &mockSession{
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
}, nil
} }
go cl.establishSecureConnection() go cl.dial()
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1)) Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
newVersion := protocol.VersionNumber(77) newVersion := protocol.VersionNumber(77)
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion)) Expect(config.Versions).To(ContainElement(newVersion))
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
newVersion = protocol.VersionNumber(78) newVersion = protocol.VersionNumber(78)
Expect(newVersion).ToNot(Equal(cl.version)) Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion)) Expect(config.Versions).To(ContainElement(newVersion))
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
}) })
It("errors if no matching version is found", func() { It("errors if no matching version is found", func() {
@ -482,7 +486,7 @@ var _ = Describe("Client", func() {
Expect(sess.closed).To(BeFalse()) Expect(sess.closed).To(BeFalse())
}) })
It("creates new sessions with the right parameters", func() { It("creates new GQUIC sessions with the right parameters", func() {
closeErr := errors.New("peer doesn't reply") closeErr := errors.New("peer doesn't reply")
c := make(chan struct{}) c := make(chan struct{})
var cconn connection var cconn connection
@ -516,12 +520,84 @@ var _ = Describe("Client", func() {
Eventually(c).Should(BeClosed()) Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn)) Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(hostname).To(Equal("quic.clemente.io")) Expect(hostname).To(Equal("quic.clemente.io"))
Expect(version).To(Equal(cl.version)) Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions)) Expect(conf.Versions).To(Equal(config.Versions))
sess.Close(closeErr) sess.Close(closeErr)
Eventually(dialed).Should(BeClosed()) Eventually(dialed).Should(BeClosed())
}) })
It("creates new TLS sessions with the right parameters", func() {
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
c := make(chan struct{})
var cconn connection
var hostname string
var version protocol.VersionNumber
var conf *Config
newTLSClientSession = func(
connP connection,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
configP *Config,
tls handshake.MintTLS,
paramsChan <-chan handshake.TransportParameters,
_ protocol.PacketNumber,
) (packetHandler, error) {
cconn = connP
hostname = hostnameP
version = versionP
conf = configP
close(c)
return sess, nil
}
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
close(dialed)
}()
Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(hostname).To(Equal("quic.clemente.io"))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
sess.Close(errors.New("peer doesn't reply"))
Eventually(dialed).Should(BeClosed())
})
It("creates a new session when the server performs a retry", func() {
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
sessionChan := make(chan *mockSession)
newTLSClientSession = func(
connP connection,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
configP *Config,
tls handshake.MintTLS,
paramsChan <-chan handshake.TransportParameters,
_ protocol.PacketNumber,
) (packetHandler, error) {
sess := &mockSession{
stopRunLoop: make(chan struct{}),
}
sessionChan <- sess
return sess, nil
}
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
close(dialed)
}()
var firstSession, secondSession *mockSession
Eventually(sessionChan).Should(Receive(&firstSession))
firstSession.Close(handshake.ErrCloseSessionForRetry)
Eventually(sessionChan).Should(Receive(&secondSession))
secondSession.Close(errors.New("stop test"))
Eventually(dialed).Should(BeClosed())
})
Context("handling packets", func() { Context("handling packets", func() {
It("handles packets", func() { It("handles packets", func() {
ph := wire.Header{ ph := wire.Header{
@ -532,7 +608,7 @@ var _ = Describe("Client", func() {
b := &bytes.Buffer{} b := &bytes.Buffer{}
err := ph.Write(b, protocol.PerspectiveServer, cl.version) err := ph.Write(b, protocol.PerspectiveServer, cl.version)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
packetConn.dataToRead = b.Bytes() packetConn.dataToRead <- b.Bytes()
Expect(sess.packetCount).To(BeZero()) Expect(sess.packetCount).To(BeZero())
stoppedListening := make(chan struct{}) stoppedListening := make(chan struct{})

View file

@ -2,7 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"io" "errors"
"net" "net"
"time" "time"
@ -12,7 +12,7 @@ import (
type mockPacketConn struct { type mockPacketConn struct {
addr net.Addr addr net.Addr
dataToRead []byte dataToRead chan []byte
dataReadFrom net.Addr dataReadFrom net.Addr
readErr error readErr error
dataWritten bytes.Buffer dataWritten bytes.Buffer
@ -20,23 +20,34 @@ type mockPacketConn struct {
closed bool closed bool
} }
func newMockPacketConn() *mockPacketConn {
return &mockPacketConn{
dataToRead: make(chan []byte, 1000),
}
}
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
if c.readErr != nil { if c.readErr != nil {
return 0, nil, c.readErr return 0, nil, c.readErr
} }
if c.dataToRead == nil { // block if there's no data data, ok := <-c.dataToRead
time.Sleep(time.Hour) if !ok {
return 0, nil, io.EOF return 0, nil, errors.New("connection closed")
} }
n := copy(b, c.dataToRead) n := copy(b, data)
c.dataToRead = nil
return n, c.dataReadFrom, nil return n, c.dataReadFrom, nil
} }
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
c.dataWrittenTo = addr c.dataWrittenTo = addr
return c.dataWritten.Write(b) return c.dataWritten.Write(b)
} }
func (c *mockPacketConn) Close() error { c.closed = true; return nil } func (c *mockPacketConn) Close() error {
if !c.closed {
close(c.dataToRead)
}
c.closed = true
return nil
}
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr } func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") } func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") } func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
@ -53,7 +64,7 @@ var _ = Describe("Connection", func() {
IP: net.IPv4(192, 168, 100, 200), IP: net.IPv4(192, 168, 100, 200),
Port: 1337, Port: 1337,
} }
packetConn = &mockPacketConn{} packetConn = newMockPacketConn()
c = &conn{ c = &conn{
currentAddr: addr, currentAddr: addr,
pconn: packetConn, pconn: packetConn,
@ -68,7 +79,7 @@ var _ = Describe("Connection", func() {
}) })
It("reads", func() { It("reads", func() {
packetConn.dataToRead = []byte("foo") packetConn.dataToRead <- []byte("foo")
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336} packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
p := make([]byte, 10) p := make([]byte, 10)
n, raddr, err := c.Read(p) n, raddr, err := c.Read(p)

View file

@ -7,33 +7,33 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type cookieHandler struct { type CookieHandler struct {
callback func(net.Addr, *Cookie) bool callback func(net.Addr, *Cookie) bool
cookieGenerator *CookieGenerator cookieGenerator *CookieGenerator
} }
var _ mint.CookieHandler = &cookieHandler{} var _ mint.CookieHandler = &CookieHandler{}
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) { func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator() cookieGenerator, err := NewCookieGenerator()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cookieHandler{ return &CookieHandler{
callback: callback, callback: callback,
cookieGenerator: cookieGenerator, cookieGenerator: cookieGenerator,
}, nil }, nil
} }
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) { func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
if h.callback(conn.RemoteAddr(), nil) { if h.callback(conn.RemoteAddr(), nil) {
return nil, nil return nil, nil
} }
return h.cookieGenerator.NewToken(conn.RemoteAddr()) return h.cookieGenerator.NewToken(conn.RemoteAddr())
} }
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool { func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token) data, err := h.cookieGenerator.DecodeToken(token)
if err != nil { if err != nil {
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error()) utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())

View file

@ -2,6 +2,7 @@ package handshake
import ( import (
"net" "net"
"time"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
@ -9,22 +10,37 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type mockConn struct {
remoteAddr net.Addr
}
var _ net.Conn = &mockConn{}
func (c *mockConn) Read([]byte) (int, error) { panic("not implemented") }
func (c *mockConn) Write([]byte) (int, error) { panic("not implemented") }
func (c *mockConn) Close() error { panic("not implemented") }
func (c *mockConn) LocalAddr() net.Addr { panic("not implemented") }
func (c *mockConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *mockConn) SetReadDeadline(time.Time) error { panic("not implemented") }
func (c *mockConn) SetWriteDeadline(time.Time) error { panic("not implemented") }
func (c *mockConn) SetDeadline(time.Time) error { panic("not implemented") }
var callbackReturn bool var callbackReturn bool
var mockCallback = func(net.Addr, *Cookie) bool { var mockCallback = func(net.Addr, *Cookie) bool {
return callbackReturn return callbackReturn
} }
var _ = Describe("Cookie Handler", func() { var _ = Describe("Cookie Handler", func() {
var ch *cookieHandler var ch *CookieHandler
var conn *mint.Conn var conn *mint.Conn
BeforeEach(func() { BeforeEach(func() {
callbackReturn = false callbackReturn = false
var err error var err error
ch, err = newCookieHandler(mockCallback) ch, err = NewCookieHandler(mockCallback)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46} addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46}
conn = mint.NewConn(&fakeConn{remoteAddr: addr}, &mint.Config{}, false) conn = mint.NewConn(&mockConn{remoteAddr: addr}, &mint.Config{}, false)
}) })
It("generates and validates a token", func() { It("generates and validates a token", func() {

View file

@ -381,10 +381,6 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data h.divNonceChan <- data
} }
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupClient) sendCHLO() error { func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++ h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos { if h.clientHelloCounter > protocol.MaxClientHellos {

View file

@ -458,10 +458,6 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer") panic("not needed for cryptoSetupServer")
} }
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 { if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")

View file

@ -1,10 +1,9 @@
package handshake package handshake
import ( import (
"crypto/tls" "errors"
"fmt" "fmt"
"io" "io"
"net"
"sync" "sync"
"github.com/bifurcation/mint" "github.com/bifurcation/mint"
@ -12,6 +11,9 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
// KeyDerivationFunction is used for key derivation // KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
@ -20,68 +22,31 @@ type cryptoSetupTLS struct {
perspective protocol.Perspective perspective protocol.Perspective
tls mintTLS
conn *fakeConn
nextPacketType protocol.PacketType
keyDerivation KeyDerivationFunction keyDerivation KeyDerivationFunction
nullAEAD crypto.AEAD nullAEAD crypto.AEAD
aead crypto.AEAD aead crypto.AEAD
aeadChanged chan<- protocol.EncryptionLevel tls MintTLS
cryptoStream *CryptoStreamConn
aeadChanged chan<- protocol.EncryptionLevel
} }
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer( func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter, tls MintTLS,
connID protocol.ConnectionID, cryptoStream *CryptoStreamConn,
tlsConfig *tls.Config, nullAEAD crypto.AEAD,
remoteAddr net.Addr,
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, aeadChanged chan<- protocol.EncryptionLevel,
checkCookie func(net.Addr, *Cookie) bool,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetup, error) { ) CryptoSetup {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
if err != nil {
return nil, err
}
mintConf.RequireCookie = true
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
if err != nil {
return nil, err
}
mintConf.CookieProtector, err = mint.NewDefaultCookieProtector()
if err != nil {
return nil, err
}
conn := &fakeConn{
stream: cryptoStream,
pers: protocol.PerspectiveServer,
remoteAddr: remoteAddr,
}
mintConn := mint.Server(conn, mintConf)
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupTLS{ return &cryptoSetupTLS{
perspective: protocol.PerspectiveServer, tls: tls,
tls: &mintController{mintConn}, cryptoStream: cryptoStream,
conn: conn,
nullAEAD: nullAEAD, nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys, keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged, aeadChanged: aeadChanged,
}, nil }
} }
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
@ -89,60 +54,44 @@ func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter, cryptoStream io.ReadWriter,
connID protocol.ConnectionID, connID protocol.ConnectionID,
hostname string, hostname string,
tlsConfig *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel, aeadChanged chan<- protocol.EncryptionLevel,
initialVersion protocol.VersionNumber, tls MintTLS,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) (CryptoSetup, error) { ) (CryptoSetup, error) {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
if err != nil {
return nil, err
}
mintConf.ServerName = hostname
conn := &fakeConn{
stream: cryptoStream,
pers: protocol.PerspectiveClient,
}
mintConn := mint.Client(conn, mintConf)
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &cryptoSetupTLS{ return &cryptoSetupTLS{
conn: conn, perspective: protocol.PerspectiveClient,
perspective: protocol.PerspectiveClient, tls: tls,
tls: &mintController{mintConn}, nullAEAD: nullAEAD,
nullAEAD: nullAEAD, keyDerivation: crypto.DeriveAESKeys,
keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged,
aeadChanged: aeadChanged,
nextPacketType: protocol.PacketTypeInitial,
}, nil }, nil
} }
func (h *cryptoSetupTLS) HandleCryptoStream() error { func (h *cryptoSetupTLS) HandleCryptoStream() error {
if h.perspective == protocol.PerspectiveServer {
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
// send out that data now
if _, err := h.cryptoStream.Flush(); err != nil {
return err
}
}
handshakeLoop: handshakeLoop:
for { for {
switch alert := h.tls.Handshake(); alert { if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
case mint.AlertStatelessRetry:
case mint.AlertNoAlert: // handshake complete
break handshakeLoop
case mint.AlertWouldBlock:
h.determineNextPacketType()
if err := h.conn.Continue(); err != nil {
return err
}
default:
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
} }
switch h.tls.State() {
case mint.StateClientStart: // this happens if a stateless retry is performed
return ErrCloseSessionForRetry
case mint.StateClientConnected, mint.StateServerConnected:
break handshakeLoop
}
} }
aead, err := h.keyDerivation(h.tls, h.perspective) aead, err := h.keyDerivation(h.tls, h.perspective)
@ -209,35 +158,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
return protocol.EncryptionUnencrypted, h.nullAEAD return protocol.EncryptionUnencrypted, h.nullAEAD
} }
func (h *cryptoSetupTLS) determineNextPacketType() error {
h.mutex.Lock()
defer h.mutex.Unlock()
state := h.tls.State().HandshakeState
if h.perspective == protocol.PerspectiveServer {
switch state {
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
h.nextPacketType = protocol.PacketTypeRetry
case "ServerStateWaitFinished":
h.nextPacketType = protocol.PacketTypeHandshake
default:
// TODO: accept 0-RTT data
return fmt.Errorf("Unexpected handshake state: %s", state)
}
return nil
}
// client
if state != "ClientStateWaitSH" {
h.nextPacketType = protocol.PacketTypeHandshake
}
return nil
}
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.nextPacketType
}
func (h *cryptoSetupTLS) DiversificationNonce() []byte { func (h *cryptoSetupTLS) DiversificationNonce() []byte {
panic("diversification nonce not needed for TLS") panic("diversification nonce not needed for TLS")
} }

View file

@ -1,7 +1,6 @@
package handshake package handshake
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
@ -10,7 +9,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/mocks/crypto" "github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake" "github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
@ -23,52 +21,33 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
var _ = Describe("TLS Crypto Setup", func() { var _ = Describe("TLS Crypto Setup", func() {
var ( var (
cs *cryptoSetupTLS cs *cryptoSetupTLS
paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel aeadChanged chan protocol.EncryptionLevel
) )
BeforeEach(func() { BeforeEach(func() {
paramsChan = make(chan TransportParameters)
aeadChanged = make(chan protocol.EncryptionLevel, 2) aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupTLSServer( cs = NewCryptoSetupTLSServer(
nil, nil,
1, NewCryptoStreamConn(nil),
testdata.GetTLSConfig(), nil, // AEAD
nil,
&TransportParameters{},
paramsChan,
aeadChanged, aeadChanged,
nil,
nil,
protocol.VersionTLS, protocol.VersionTLS,
) ).(*cryptoSetupTLS)
Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupTLS)
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
}) })
It("errors when the handshake fails", func() { It("errors when the handshake fails", func() {
alert := mint.AlertBadRecordMAC alert := mint.AlertBadRecordMAC
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert) cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(alert)
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
}) })
It("continues shaking hands when mint says that it would block", func() {
cs.conn.stream = &bytes.Buffer{}
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{})
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
})
It("derives keys", func() { It("derives keys", func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -76,64 +55,22 @@ var _ = Describe("TLS Crypto Setup", func() {
Expect(aeadChanged).To(BeClosed()) Expect(aeadChanged).To(BeClosed())
}) })
Context("determining the packet type", func() { It("handshakes until it is connected", func() {
Context("for the client", func() { cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
var csClient *cryptoSetupTLS cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerNegotiated).Times(9)
BeforeEach(func() { cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
csInt, err := NewCryptoSetupTLSClient( cs.keyDerivation = mockKeyDerivation
nil, err := cs.HandleCryptoStream()
1, Expect(err).ToNot(HaveOccurred())
"quic.clemente.io", Expect(aeadChanged).To(Receive())
testdata.GetTLSConfig(),
&TransportParameters{},
paramsChan,
aeadChanged,
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
csClient = csInt.(*cryptoSetupTLS)
csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl)
})
It("sends a Client Initial first", func() {
Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeInitial))
})
It("sends a Handshake packet after the server sent a Server Hello", func() {
csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"})
err := csClient.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
})
})
Context("for the server", func() {
BeforeEach(func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
})
It("sends a Stateless Retry packet", func() {
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"})
err := cs.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeRetry))
})
It("sends Handshake packet", func() {
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"})
err := cs.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeHandshake))
})
})
}) })
Context("escalating crypto", func() { Context("escalating crypto", func() {
doHandshake := func() { doHandshake := func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.keyDerivation = mockKeyDerivation cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream() err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -240,3 +177,33 @@ var _ = Describe("TLS Crypto Setup", func() {
}) })
}) })
}) })
var _ = Describe("TLS Crypto Setup, for the client", func() {
var (
cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel
)
BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupTLSClient(
nil,
0,
"quic.clemente.io",
aeadChanged,
nil, // mintTLS
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupTLS)
})
It("returns when a retry is performed", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateClientStart)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(ErrCloseSessionForRetry))
})
})

View file

@ -0,0 +1,101 @@
package handshake
import (
"bytes"
"io"
"net"
"time"
)
// The CryptoStreamConn is used as the net.Conn passed to mint.
// It has two operating modes:
// 1. It can read and write to bytes.Buffers.
// 2. It can use a quic.Stream for reading and writing.
// The buffer-mode is only used by the server, in order to statelessly handle retries.
type CryptoStreamConn struct {
remoteAddr net.Addr
// the buffers are used before the session is initialized
readBuf bytes.Buffer
writeBuf bytes.Buffer
// stream will be set once the session is initialized
stream io.ReadWriter
}
var _ net.Conn = &CryptoStreamConn{}
// NewCryptoStreamConn creates a new CryptoStreamConn
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
return &CryptoStreamConn{remoteAddr: remoteAddr}
}
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
if c.stream != nil {
return c.stream.Read(b)
}
return c.readBuf.Read(b)
}
// AddDataForReading adds data to the read buffer.
// This data will ONLY be read when the stream has not been set.
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
c.readBuf.Write(data)
}
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
if c.stream != nil {
return c.stream.Write(p)
}
return c.writeBuf.Write(p)
}
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
func (c *CryptoStreamConn) GetDataForWriting() []byte {
defer c.writeBuf.Reset()
data := make([]byte, c.writeBuf.Len())
copy(data, c.writeBuf.Bytes())
return data
}
// SetStream sets the stream.
// After setting the stream, the read and write buffer won't be used any more.
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
c.stream = stream
}
// Flush copies the contents of the write buffer to the stream
func (c *CryptoStreamConn) Flush() (int, error) {
n, err := io.Copy(c.stream, &c.writeBuf)
return int(n), err
}
// Close is not implemented
func (c *CryptoStreamConn) Close() error {
return nil
}
// LocalAddr is not implemented
func (c *CryptoStreamConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr returns the remote address
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
// SetReadDeadline is not implemented
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is not implemented
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// SetDeadline is not implemented
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
return nil
}

View file

@ -0,0 +1,67 @@
package handshake
import (
"bytes"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("CryptoStreamConn", func() {
var (
csc *CryptoStreamConn
remoteAddr net.Addr
)
BeforeEach(func() {
remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
csc = NewCryptoStreamConn(remoteAddr)
})
It("reads from the read buffer, when no stream is set", func() {
csc.AddDataForReading([]byte("foobar"))
data := make([]byte, 4)
n, err := csc.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(data).To(Equal([]byte("foob")))
})
It("writes to the write buffer, when no stream is set", func() {
csc.Write([]byte("foo"))
Expect(csc.GetDataForWriting()).To(Equal([]byte("foo")))
csc.Write([]byte("bar"))
Expect(csc.GetDataForWriting()).To(Equal([]byte("bar")))
})
It("reads from the stream, if available", func() {
csc.stream = &bytes.Buffer{}
csc.stream.Write([]byte("foobar"))
data := make([]byte, 3)
n, err := csc.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(data).To(Equal([]byte("foo")))
})
It("writes to the stream, if available", func() {
stream := &bytes.Buffer{}
csc.SetStream(stream)
csc.Write([]byte("foobar"))
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
It("returns the remote address", func() {
Expect(csc.RemoteAddr()).To(Equal(remoteAddr))
})
It("has unimplemented methods", func() {
Expect(csc.Close()).ToNot(HaveOccurred())
Expect(csc.SetDeadline(time.Time{})).ToNot(HaveOccurred())
Expect(csc.SetReadDeadline(time.Time{})).ToNot(HaveOccurred())
Expect(csc.SetWriteDeadline(time.Time{})).ToNot(HaveOccurred())
Expect(csc.LocalAddr()).To(BeNil())
})
})

View file

@ -1,6 +1,10 @@
package handshake package handshake
import ( import (
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
) )
@ -10,14 +14,33 @@ type Sealer interface {
Overhead() int Overhead() int
} }
// A TLSExtensionHandler sends and received the QUIC TLS extension.
// It provides the parameters sent by the peer on a channel.
type TLSExtensionHandler interface {
Send(mint.HandshakeType, *mint.ExtensionList) error
Receive(mint.HandshakeType, *mint.ExtensionList) error
GetPeerParams() <-chan TransportParameters
}
// MintTLS combines some methods needed to interact with mint.
type MintTLS interface {
crypto.TLSExporter
// additional methods
Handshake() mint.Alert
State() mint.State
SetCryptoStream(io.ReadWriter)
SetExtensionHandler(mint.AppExtensionHandler) error
}
// CryptoSetup is a crypto setup // CryptoSetup is a crypto setup
type CryptoSetup interface { type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
HandleCryptoStream() error HandleCryptoStream() error
// TODO: clean up this interface // TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)

View file

@ -1,127 +0,0 @@
package handshake
import (
"bytes"
gocrypto "crypto"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
mconf := &mint.Config{
NonBlocking: true,
CipherSuites: []mint.CipherSuite{
mint.TLS_AES_128_GCM_SHA256,
mint.TLS_AES_256_GCM_SHA384,
},
}
if tlsConf != nil {
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
}
for j, cert := range certChain.Certificate {
c, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
mconf.Certificates[i].Chain[j] = c
}
}
}
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
return nil, err
}
return mconf, nil
}
type mintTLS interface {
// These two methods are the same as the crypto.TLSExporter interface.
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
GetCipherSuite() mint.CipherSuiteParams
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
// additional methods
Handshake() mint.Alert
State() mint.ConnectionState
}
var _ crypto.TLSExporter = (mintTLS)(nil)
type mintController struct {
conn *mint.Conn
}
var _ mintTLS = &mintController{}
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
return mc.conn.State().CipherSuite
}
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
return mc.conn.ComputeExporter(label, context, keyLength)
}
func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
func (mc *mintController) State() mint.ConnectionState {
return mc.conn.State()
}
// mint expects a net.Conn, but we're doing the handshake on a stream
// so we wrap a stream such that implements a net.Conn
type fakeConn struct {
stream io.ReadWriter
pers protocol.Perspective
remoteAddr net.Addr
blockRead bool
writeBuffer bytes.Buffer
}
var _ net.Conn = &fakeConn{}
func (c *fakeConn) Read(b []byte) (int, error) {
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
return 0, nil
}
c.blockRead = true // block the next Read call
return c.stream.Read(b)
}
func (c *fakeConn) Write(p []byte) (int, error) {
if c.pers == protocol.PerspectiveClient {
return c.stream.Write(p)
}
// Buffer all writes by the server.
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
return c.writeBuffer.Write(p)
}
func (c *fakeConn) Continue() error {
c.blockRead = false
if c.pers == protocol.PerspectiveClient {
return nil
}
// write all contents of the write buffer to the stream.
_, err := c.stream.Write(c.writeBuffer.Bytes())
c.writeBuffer.Reset()
return err
}
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
func (c *fakeConn) SetDeadline(time.Time) error { return nil }

View file

@ -1,72 +0,0 @@
package handshake
import (
"bytes"
"net"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Fake Conn", func() {
var (
c *fakeConn
stream *bytes.Buffer
)
BeforeEach(func() {
stream = &bytes.Buffer{}
c = &fakeConn{stream: stream}
})
Context("Reading", func() {
It("doesn't return any new data after one Read call", func() {
stream.Write([]byte("foobar"))
b := make([]byte, 3)
_, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("foo")))
n, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(BeZero())
})
It("allows more Read calls after unblocking", func() {
stream.Write([]byte("foobar"))
b := make([]byte, 3)
_, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
err = c.Continue()
Expect(err).ToNot(HaveOccurred())
_, err = c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("bar")))
})
})
Context("Writing", func() {
It("writes directly when acting as a client", func() {
c.pers = protocol.PerspectiveClient
_, err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
It("only writes after flushing when acting as a server", func() {
c.pers = protocol.PerspectiveServer
_, err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(BeEmpty())
err = c.Continue()
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
})
It("returns its remote address", func() {
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
c.remoteAddr = addr
Expect(c.RemoteAddr()).To(Equal(addr))
})
})

View file

@ -13,8 +13,8 @@ import (
) )
type extensionHandlerClient struct { type extensionHandlerClient struct {
params *TransportParameters ourParams *TransportParameters
paramsChan chan<- TransportParameters paramsChan chan TransportParameters
initialVersion protocol.VersionNumber initialVersion protocol.VersionNumber
supportedVersions []protocol.VersionNumber supportedVersions []protocol.VersionNumber
@ -22,16 +22,17 @@ type extensionHandlerClient struct {
} }
var _ mint.AppExtensionHandler = &extensionHandlerClient{} var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{}
func newExtensionHandlerClient( func NewExtensionHandlerClient(
params *TransportParameters, params *TransportParameters,
paramsChan chan<- TransportParameters,
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) *extensionHandlerClient { ) TLSExtensionHandler {
paramsChan := make(chan TransportParameters, 1)
return &extensionHandlerClient{ return &extensionHandlerClient{
params: params, ourParams: params,
paramsChan: paramsChan, paramsChan: paramsChan,
initialVersion: initialVersion, initialVersion: initialVersion,
supportedVersions: supportedVersions, supportedVersions: supportedVersions,
@ -46,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
data, err := syntax.Marshal(clientHelloTransportParameters{ data, err := syntax.Marshal(clientHelloTransportParameters{
InitialVersion: uint32(h.initialVersion), InitialVersion: uint32(h.initialVersion),
Parameters: h.params.getTransportParameters(), Parameters: h.ourParams.getTransportParameters(),
}) })
if err != nil { if err != nil {
return err return err
@ -123,3 +124,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
h.paramsChan <- *params h.paramsChan <- *params
return nil return nil
} }
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View file

@ -13,15 +13,12 @@ import (
var _ = Describe("TLS Extension Handler, for the client", func() { var _ = Describe("TLS Extension Handler, for the client", func() {
var ( var (
handler *extensionHandlerClient handler *extensionHandlerClient
el mint.ExtensionList el mint.ExtensionList
paramsChan chan TransportParameters
) )
BeforeEach(func() { BeforeEach(func() {
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever).(*extensionHandlerClient)
paramsChan = make(chan TransportParameters, 1)
handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever)
el = make(mint.ExtensionList, 0) el = make(mint.ExtensionList, 0)
}) })
@ -81,7 +78,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var params TransportParameters var params TransportParameters
Expect(paramsChan).To(Receive(&params)) Expect(handler.GetPeerParams()).To(Receive(&params))
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344)) Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
}) })

View file

@ -14,26 +14,27 @@ import (
) )
type extensionHandlerServer struct { type extensionHandlerServer struct {
params *TransportParameters ourParams *TransportParameters
paramsChan chan<- TransportParameters paramsChan chan TransportParameters
version protocol.VersionNumber version protocol.VersionNumber
supportedVersions []protocol.VersionNumber supportedVersions []protocol.VersionNumber
} }
var _ mint.AppExtensionHandler = &extensionHandlerServer{} var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{}
func newExtensionHandlerServer( func NewExtensionHandlerServer(
params *TransportParameters, params *TransportParameters,
paramsChan chan<- TransportParameters,
supportedVersions []protocol.VersionNumber, supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber, version protocol.VersionNumber,
) *extensionHandlerServer { ) TLSExtensionHandler {
paramsChan := make(chan TransportParameters, 1)
return &extensionHandlerServer{ return &extensionHandlerServer{
params: params, ourParams: params,
paramsChan: paramsChan, paramsChan: paramsChan,
version: version,
supportedVersions: supportedVersions, supportedVersions: supportedVersions,
version: version,
} }
} }
@ -43,7 +44,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
} }
transportParams := append( transportParams := append(
h.params.getTransportParameters(), h.ourParams.getTransportParameters(),
// TODO(#855): generate a real token // TODO(#855): generate a real token
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)}, transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
) )
@ -105,3 +106,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
h.paramsChan <- *params h.paramsChan <- *params
return nil return nil
} }
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View file

@ -20,15 +20,12 @@ func parameterMapToList(paramMap map[transportParameterID][]byte) []transportPar
var _ = Describe("TLS Extension Handler, for the server", func() { var _ = Describe("TLS Extension Handler, for the server", func() {
var ( var (
handler *extensionHandlerServer handler *extensionHandlerServer
el mint.ExtensionList el mint.ExtensionList
paramsChan chan TransportParameters
) )
BeforeEach(func() { BeforeEach(func() {
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever).(*extensionHandlerServer)
paramsChan = make(chan TransportParameters, 1)
handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever)
el = make(mint.ExtensionList, 0) el = make(mint.ExtensionList, 0)
}) })
@ -91,7 +88,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
err := handler.Receive(mint.HandshakeTypeClientHello, &el) err := handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var params TransportParameters var params TransportParameters
Expect(paramsChan).To(Receive(&params)) Expect(handler.GetPeerParams()).To(Receive(&params))
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344)) Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
}) })

View file

@ -1,6 +1,7 @@
package mocks package mocks
//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go" //go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS"
//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController" //go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD" //go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"

View file

@ -1,83 +1,106 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: ../handshake/mint_utils.go // Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
package mockhandshake package mockhandshake
import ( import (
io "io"
reflect "reflect" reflect "reflect"
mint "github.com/bifurcation/mint" mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockmintTLS is a mock of mintTLS interface // MockMintTLS is a mock of MintTLS interface
type MockmintTLS struct { type MockMintTLS struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockmintTLSMockRecorder recorder *MockMintTLSMockRecorder
} }
// MockmintTLSMockRecorder is the mock recorder for MockmintTLS // MockMintTLSMockRecorder is the mock recorder for MockMintTLS
type MockmintTLSMockRecorder struct { type MockMintTLSMockRecorder struct {
mock *MockmintTLS mock *MockMintTLS
} }
// NewMockmintTLS creates a new mock instance // NewMockMintTLS creates a new mock instance
func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS { func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
mock := &MockmintTLS{ctrl: ctrl} mock := &MockMintTLS{ctrl: ctrl}
mock.recorder = &MockmintTLSMockRecorder{mock} mock.recorder = &MockMintTLSMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use // EXPECT returns an object that allows the caller to indicate expected use
func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder { func (_m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
return _m.recorder return _m.recorder
} }
// GetCipherSuite mocks base method
func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams {
ret := _m.ctrl.Call(_m, "GetCipherSuite")
ret0, _ := ret[0].(mint.CipherSuiteParams)
return ret0
}
// GetCipherSuite indicates an expected call of GetCipherSuite
func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite))
}
// ComputeExporter mocks base method // ComputeExporter mocks base method
func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { func (_m *MockMintTLS) ComputeExporter(_param0 string, _param1 []byte, _param2 int) ([]byte, error) {
ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength) ret := _m.ctrl.Call(_m, "ComputeExporter", _param0, _param1, _param2)
ret0, _ := ret[0].([]byte) ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// ComputeExporter indicates an expected call of ComputeExporter // ComputeExporter indicates an expected call of ComputeExporter
func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call { func (_mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2) return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
}
// GetCipherSuite mocks base method
func (_m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams {
ret := _m.ctrl.Call(_m, "GetCipherSuite")
ret0, _ := ret[0].(mint.CipherSuiteParams)
return ret0
}
// GetCipherSuite indicates an expected call of GetCipherSuite
func (_mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite))
} }
// Handshake mocks base method // Handshake mocks base method
func (_m *MockmintTLS) Handshake() mint.Alert { func (_m *MockMintTLS) Handshake() mint.Alert {
ret := _m.ctrl.Call(_m, "Handshake") ret := _m.ctrl.Call(_m, "Handshake")
ret0, _ := ret[0].(mint.Alert) ret0, _ := ret[0].(mint.Alert)
return ret0 return ret0
} }
// Handshake indicates an expected call of Handshake // Handshake indicates an expected call of Handshake
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call { func (_mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake)) return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
}
// SetCryptoStream mocks base method
func (_m *MockMintTLS) SetCryptoStream(_param0 io.ReadWriter) {
_m.ctrl.Call(_m, "SetCryptoStream", _param0)
}
// SetCryptoStream indicates an expected call of SetCryptoStream
func (_mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0)
}
// SetExtensionHandler mocks base method
func (_m *MockMintTLS) SetExtensionHandler(_param0 mint.AppExtensionHandler) error {
ret := _m.ctrl.Call(_m, "SetExtensionHandler", _param0)
ret0, _ := ret[0].(error)
return ret0
}
// SetExtensionHandler indicates an expected call of SetExtensionHandler
func (_mr *MockMintTLSMockRecorder) SetExtensionHandler(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetExtensionHandler", reflect.TypeOf((*MockMintTLS)(nil).SetExtensionHandler), arg0)
} }
// State mocks base method // State mocks base method
func (_m *MockmintTLS) State() mint.ConnectionState { func (_m *MockMintTLS) State() mint.State {
ret := _m.ctrl.Call(_m, "State") ret := _m.ctrl.Call(_m, "State")
ret0, _ := ret[0].(mint.ConnectionState) ret0, _ := ret[0].(mint.State)
return ret0 return ret0
} }
// State indicates an expected call of State // State indicates an expected call of State
func (_mr *MockmintTLSMockRecorder) State() *gomock.Call { func (_mr *MockMintTLSMockRecorder) State() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State)) return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State))
} }

View file

@ -0,0 +1,71 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler)
package mocks
import (
reflect "reflect"
mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
)
// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface
type MockTLSExtensionHandler struct {
ctrl *gomock.Controller
recorder *MockTLSExtensionHandlerMockRecorder
}
// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler
type MockTLSExtensionHandlerMockRecorder struct {
mock *MockTLSExtensionHandler
}
// NewMockTLSExtensionHandler creates a new mock instance
func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler {
mock := &MockTLSExtensionHandler{ctrl: ctrl}
mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (_m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder {
return _m.recorder
}
// GetPeerParams mocks base method
func (_m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters {
ret := _m.ctrl.Call(_m, "GetPeerParams")
ret0, _ := ret[0].(<-chan handshake.TransportParameters)
return ret0
}
// GetPeerParams indicates an expected call of GetPeerParams
func (_mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams))
}
// Receive mocks base method
func (_m *MockTLSExtensionHandler) Receive(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
ret := _m.ctrl.Call(_m, "Receive", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
// Receive indicates an expected call of Receive
func (_mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1)
}
// Send mocks base method
func (_m *MockTLSExtensionHandler) Send(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
ret := _m.ctrl.Call(_m, "Send", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
// Send indicates an expected call of Send
func (_mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1)
}

150
mint_utils.go Normal file
View file

@ -0,0 +1,150 @@
package quic
import (
"bytes"
gocrypto "crypto"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type mintController struct {
csc *handshake.CryptoStreamConn
conn *mint.Conn
}
var _ handshake.MintTLS = &mintController{}
func newMintController(
csc *handshake.CryptoStreamConn,
mconf *mint.Config,
pers protocol.Perspective,
) handshake.MintTLS {
var conn *mint.Conn
if pers == protocol.PerspectiveClient {
conn = mint.Client(csc, mconf)
} else {
conn = mint.Server(csc, mconf)
}
return &mintController{
csc: csc,
conn: conn,
}
}
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
return mc.conn.State().CipherSuite
}
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
return mc.conn.ComputeExporter(label, context, keyLength)
}
func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
func (mc *mintController) State() mint.State {
return mc.conn.State().HandshakeState
}
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
mc.csc.SetStream(stream)
}
func (mc *mintController) SetExtensionHandler(h mint.AppExtensionHandler) error {
return mc.conn.SetExtensionHandler(h)
}
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
mconf := &mint.Config{
NonBlocking: true,
CipherSuites: []mint.CipherSuite{
mint.TLS_AES_128_GCM_SHA256,
mint.TLS_AES_256_GCM_SHA384,
},
}
if tlsConf != nil {
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
}
for j, cert := range certChain.Certificate {
c, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
mconf.Certificates[i].Chain[j] = c
}
}
}
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
return nil, err
}
return mconf, nil
}
// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets
// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0.
func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) {
unpacker := &packetUnpacker{aead: &nullAEAD{aead}, version: version}
packet, err := unpacker.Unpack(hdr.Raw, hdr, data)
if err != nil {
return nil, err
}
var frame *wire.StreamFrame
for _, f := range packet.frames {
var ok bool
frame, ok = f.(*wire.StreamFrame)
if ok {
break
}
}
if frame == nil {
return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
}
// We don't need a check for the stream ID here.
// The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
if frame.Offset != 0 {
return nil, errors.New("received stream data with non-zero offset")
}
if utils.Debug() {
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID)
hdr.Log()
wire.LogFrame(frame, false)
}
return frame, nil
}
// packUnencryptedPacket provides a low-overhead way to pack a packet.
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) {
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)
if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
if err := sf.Write(buffer, hdr.Version); err != nil {
return nil, err
}
raw = raw[0:buffer.Len()]
_ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+aead.Overhead()]
if utils.Debug() {
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
hdr.Log()
wire.LogFrame(sf, true)
}
return raw, nil
}

111
mint_utils_test.go Normal file
View file

@ -0,0 +1,111 @@
package quic
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Packing and unpacking Initial packets", func() {
var aead crypto.AEAD
connID := protocol.ConnectionID(0x1337)
ver := protocol.VersionTLS
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
PacketNumber: 0x42,
ConnectionID: connID,
Version: ver,
}
BeforeEach(func() {
var err error
aead, err = crypto.NewNullAEAD(protocol.PerspectiveServer, connID, ver)
Expect(err).ToNot(HaveOccurred())
// set hdr.Raw
buf := &bytes.Buffer{}
err = hdr.Write(buf, protocol.PerspectiveServer, ver)
Expect(err).ToNot(HaveOccurred())
hdr.Raw = buf.Bytes()
})
Context("unpacking", func() {
packPacket := func(frames []wire.Frame) []byte {
buf := &bytes.Buffer{}
err := hdr.Write(buf, protocol.PerspectiveClient, ver)
Expect(err).ToNot(HaveOccurred())
payloadStartIndex := buf.Len()
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
for _, f := range frames {
err := f.Write(buf, ver)
Expect(err).ToNot(HaveOccurred())
}
raw := buf.Bytes()
return aeadCl.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
}
It("unpacks a packet", func() {
f := &wire.StreamFrame{
StreamID: 0,
Data: []byte("foobar"),
}
p := packPacket([]wire.Frame{f})
frame, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("rejects a packet that doesn't contain a STREAM_FRAME", func() {
p := packPacket([]wire.Frame{&wire.PingFrame{}})
_, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).To(MatchError("Packet doesn't contain a STREAM_FRAME"))
})
It("rejects a packet that has a STREAM_FRAME for the wrong stream", func() {
f := &wire.StreamFrame{
StreamID: 42,
Data: []byte("foobar"),
}
p := packPacket([]wire.Frame{f})
_, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).To(MatchError("UnencryptedStreamData: received unencrypted stream data on stream 42"))
})
It("rejects a packet that has a STREAM_FRAME with a non-zero offset", func() {
f := &wire.StreamFrame{
StreamID: 0,
Offset: 10,
Data: []byte("foobar"),
}
p := packPacket([]wire.Frame{f})
_, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).To(MatchError("received stream data with non-zero offset"))
})
})
Context("packing", func() {
var unpacker *packetUnpacker
BeforeEach(func() {
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
Expect(err).ToNot(HaveOccurred())
unpacker = &packetUnpacker{aead: &nullAEAD{aeadCl}, version: ver}
})
It("packs a packet", func() {
f := &wire.StreamFrame{
Data: []byte("foobar"),
FinBit: true,
}
data, err := packUnencryptedPacket(aead, hdr, f, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
packet, err := unpacker.Unpack(hdr.Raw, hdr, data[len(hdr.Raw):])
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
})
})

View file

@ -17,9 +17,9 @@ type packetNumberGenerator struct {
nextToSkip protocol.PacketNumber nextToSkip protocol.PacketNumber
} }
func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator { func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
return &packetNumberGenerator{ return &packetNumberGenerator{
next: 1, next: initial,
averagePeriod: averagePeriod, averagePeriod: averagePeriod,
} }
} }

View file

@ -12,7 +12,12 @@ var _ = Describe("Packet Number Generator", func() {
var png packetNumberGenerator var png packetNumberGenerator
BeforeEach(func() { BeforeEach(func() {
png = *newPacketNumberGenerator(100) png = *newPacketNumberGenerator(1, 100)
})
It("can be initialized to return any first packet number", func() {
png = *newPacketNumberGenerator(12345, 100)
Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345)))
}) })
It("gets 1 as the first packet number", func() { It("gets 1 as the first packet number", func() {

View file

@ -32,9 +32,11 @@ type packetPacker struct {
ackFrame *wire.AckFrame ackFrame *wire.AckFrame
leastUnacked protocol.PacketNumber leastUnacked protocol.PacketNumber
omitConnectionID bool omitConnectionID bool
hasSentPacket bool // has the packetPacker already sent a packet
} }
func newPacketPacker(connectionID protocol.ConnectionID, func newPacketPacker(connectionID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber,
cryptoSetup handshake.CryptoSetup, cryptoSetup handshake.CryptoSetup,
streamFramer *streamFramer, streamFramer *streamFramer,
perspective protocol.Perspective, perspective protocol.Perspective,
@ -46,7 +48,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
perspective: perspective, perspective: perspective,
version: version, version: version,
streamFramer: streamFramer, streamFramer: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
} }
} }
@ -116,7 +118,12 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
// PackPacket packs a new packet // PackPacket packs a new packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket() (*packedPacket, error) { func (p *packetPacker) PackPacket() (*packedPacket, error) {
if p.streamFramer.HasCryptoStreamFrame() { hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame()
// if this is the first packet to be send, make sure it contains stream data
if !p.hasSentPacket && !hasCryptoStreamFrame {
return nil, nil
}
if hasCryptoStreamFrame {
return p.packCryptoPacket() return p.packCryptoPacket()
} }
@ -266,18 +273,21 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
pnum := p.packetNumberGenerator.Peek() pnum := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked) packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked)
var isLongHeader bool
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
// TODO: set the Long Header type
packetNumberLen = protocol.PacketNumberLen4
isLongHeader = true
}
header := &wire.Header{ header := &wire.Header{
ConnectionID: p.connectionID, ConnectionID: p.connectionID,
PacketNumber: pnum, PacketNumber: pnum,
PacketNumberLen: packetNumberLen, PacketNumberLen: packetNumberLen,
IsLongHeader: isLongHeader, }
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
header.PacketNumberLen = protocol.PacketNumberLen4
header.IsLongHeader = true
if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
header.Type = protocol.PacketTypeInitial
// TODO(#886): add padding
} else {
header.Type = protocol.PacketTypeHandshake
}
} }
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure { if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
@ -292,7 +302,6 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
header.Version = p.version header.Version = p.version
} }
} else { } else {
header.Type = p.cryptoSetup.GetNextPacketType()
if encLevel != protocol.EncryptionForwardSecure { if encLevel != protocol.EncryptionForwardSecure {
header.Version = p.version header.Version = p.version
} }
@ -330,7 +339,7 @@ func (p *packetPacker) writeAndSealPacket(
if num != header.PacketNumber { if num != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
} }
p.hasSentPacket = true
return raw, nil return raw, nil
} }

View file

@ -28,7 +28,6 @@ type mockCryptoSetup struct {
divNonce []byte divNonce []byte
encLevelSeal protocol.EncryptionLevel encLevelSeal protocol.EncryptionLevel
encLevelSealCrypto protocol.EncryptionLevel encLevelSealCrypto protocol.EncryptionLevel
nextPacketType protocol.PacketType
} }
var _ handshake.CryptoSetup = &mockCryptoSetup{} var _ handshake.CryptoSetup = &mockCryptoSetup{}
@ -50,7 +49,6 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
} }
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce } func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType }
var _ = Describe("Packet packer", func() { var _ = Describe("Packet packer", func() {
var ( var (
@ -69,13 +67,14 @@ var _ = Describe("Packet packer", func() {
packer = &packetPacker{ packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
connectionID: 0x1337, connectionID: 0x1337,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
streamFramer: streamFramer, streamFramer: streamFramer,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
} }
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
packer.version = protocol.VersionWhatever packer.version = protocol.VersionWhatever
packer.hasSentPacket = true
}) })
It("returns nil when no packet is queued", func() { It("returns nil when no packet is queued", func() {
@ -191,13 +190,6 @@ var _ = Describe("Packet packer", func() {
Expect(h.Version).To(Equal(versionIETFHeader)) Expect(h.Version).To(Equal(versionIETFHeader))
}) })
It("sets the packet type based on the state of the handshake", func() {
packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5
h := packer.getHeader(protocol.EncryptionSecure)
Expect(h.IsLongHeader).To(BeTrue())
Expect(h.Type).To(Equal(protocol.PacketType(5)))
})
It("uses the Short Header format for forward-secure packets", func() { It("uses the Short Header format for forward-secure packets", func() {
h := packer.getHeader(protocol.EncryptionForwardSecure) h := packer.getHeader(protocol.EncryptionForwardSecure)
Expect(h.IsLongHeader).To(BeFalse()) Expect(h.IsLongHeader).To(BeFalse())
@ -269,7 +261,7 @@ var _ = Describe("Packet packer", func() {
Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber)) Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber))
}) })
It("packs a StopWaitingFrame first", func() { It("packs a STOP_WAITING frame first", func() {
packer.packetNumberGenerator.next = 15 packer.packetNumberGenerator.next = 15
swf := &wire.StopWaitingFrame{LeastUnacked: 10} swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.RstStreamFrame{})
@ -281,7 +273,7 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames[0]).To(Equal(swf)) Expect(p.frames[0]).To(Equal(swf))
}) })
It("sets the LeastUnackedDelta length of a StopWaitingFrame", func() { It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
packer.packetNumberGenerator.next = packetNumber packer.packetNumberGenerator.next = packetNumber
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
@ -292,7 +284,7 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
}) })
It("does not pack a packet containing only a StopWaitingFrame", func() { It("does not pack a packet containing only a STOP_WAITING frame", func() {
swf := &wire.StopWaitingFrame{LeastUnacked: 10} swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(swf) packer.QueueControlFrame(swf)
p, err := packer.PackPacket() p, err := packer.PackPacket()
@ -307,6 +299,14 @@ var _ = Describe("Packet packer", func() {
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
}) })
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
packer.hasSentPacket = false
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs many control frames into 1 packets", func() { It("packs many control frames into 1 packets", func() {
f := &wire.AckFrame{LargestAcked: 1} f := &wire.AckFrame{LargestAcked: 1}
b := &bytes.Buffer{} b := &bytes.Buffer{}
@ -602,7 +602,7 @@ var _ = Describe("Packet packer", func() {
}) })
}) })
Context("Blocked frames", func() { Context("BLOCKED frames", func() {
It("queues a BLOCKED frame", func() { It("queues a BLOCKED frame", func() {
length := 100 length := 100
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}} streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
@ -750,7 +750,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")) Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment"))
}) })
It("refuses to retransmit packets without a StopWaitingFrame", func() { It("refuses to retransmit packets without a STOP_WAITING Frame", func() {
packer.stopWaiting = nil packer.stopWaiting = nil
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{ _, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.EncryptionSecure, EncryptionLevel: protocol.EncryptionSecure,

106
server.go
View file

@ -19,6 +19,7 @@ import (
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
Session Session
getCryptoStream() cryptoStream
handshakeStatus() <-chan handshakeEvent handshakeStatus() <-chan handshakeEvent
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
@ -33,6 +34,9 @@ type server struct {
conn net.PacketConn conn net.PacketConn
supportsTLS bool
serverTLS *serverTLS
certChain crypto.CertChain certChain crypto.CertChain
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
@ -77,11 +81,21 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
if err != nil { if err != nil {
return nil, err return nil, err
} }
config = populateServerConfig(config)
// check if any of the supported versions supports TLS
var supportsTLS bool
for _, v := range config.Versions {
if v.UsesTLS() {
supportsTLS = true
break
}
}
s := &server{ s := &server{
conn: conn, conn: conn,
tlsConf: tlsConf, tlsConf: tlsConf,
config: populateServerConfig(config), config: config,
certChain: certChain, certChain: certChain,
scfg: scfg, scfg: scfg,
sessions: map[protocol.ConnectionID]packetHandler{}, sessions: map[protocol.ConnectionID]packetHandler{},
@ -89,12 +103,47 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
sessionQueue: make(chan Session, 5), sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
supportsTLS: supportsTLS,
}
if supportsTLS {
if err := s.setupTLS(); err != nil {
return nil, err
}
} }
go s.serve() go s.serve()
utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil return s, nil
} }
func (s *server) setupTLS() error {
cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie)
if err != nil {
return err
}
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf)
if err != nil {
return err
}
s.serverTLS = serverTLS
// handle TLS connection establishment statelessly
go func() {
for {
select {
case <-s.errorChan:
return
case sess := <-sessionChan:
// TODO: think about what to do with connection ID collisions
connID := sess.(*session).connectionID
s.sessionsMutex.Lock()
s.sessions[connID] = sess
s.sessionsMutex.Unlock()
s.runHandshakeAndSession(sess, connID)
}
}
}()
return nil
}
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool { var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
if cookie == nil { if cookie == nil {
return false return false
@ -225,8 +274,16 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) return qerr.Error(qerr.InvalidPacketHeader, err.Error())
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
connID := hdr.ConnectionID connID := hdr.ConnectionID
if hdr.Type == protocol.PacketTypeInitial {
if s.supportsTLS {
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
}
return nil
}
s.sessionsMutex.RLock() s.sessionsMutex.RLock()
session, sessionKnown := s.sessions[connID] session, sessionKnown := s.sessions[connID]
s.sessionsMutex.RUnlock() s.sessionsMutex.RUnlock()
@ -279,11 +336,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return err return err
} }
} }
// send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header
if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
_, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.config.Versions), remoteAddr)
return err
}
if !sessionKnown { if !sessionKnown {
version := hdr.Version version := hdr.Version
@ -307,34 +359,38 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
s.sessions[connID] = session s.sessions[connID] = session
s.sessionsMutex.Unlock() s.sessionsMutex.Unlock()
go func() { s.runHandshakeAndSession(session, connID)
// session.run() returns as soon as the session is closed
_ = session.run()
s.removeConnection(connID)
}()
go func() {
for {
ev := <-session.handshakeStatus()
if ev.err != nil {
return
}
if ev.encLevel == protocol.EncryptionForwardSecure {
break
}
}
s.sessionQueue <- session
}()
} }
session.handlePacket(&receivedPacket{ session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
header: hdr, header: hdr,
data: packet[len(packet)-r.Len():], data: packetData,
rcvTime: rcvTime, rcvTime: rcvTime,
}) })
return nil return nil
} }
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
go func() {
_ = session.run()
// session.run() returns as soon as the session is closed
s.removeConnection(connID)
}()
go func() {
for {
ev := <-session.handshakeStatus()
if ev.err != nil {
return
}
if ev.encLevel == protocol.EncryptionForwardSecure {
break
}
}
s.sessionQueue <- session
}()
}
func (s *server) removeConnection(id protocol.ConnectionID) { func (s *server) removeConnection(id protocol.ConnectionID) {
s.sessionsMutex.Lock() s.sessionsMutex.Lock()
s.sessions[id] = nil s.sessions[id] = nil

View file

@ -12,6 +12,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
@ -67,6 +68,7 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not imple
func (*mockSession) Context() context.Context { panic("not implemented") } func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever } func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan } func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") }
var _ Session = &mockSession{} var _ Session = &mockSession{}
var _ NonFWSession = &mockSession{} var _ NonFWSession = &mockSession{}
@ -96,7 +98,8 @@ var _ = Describe("Server", func() {
) )
BeforeEach(func() { BeforeEach(func() {
conn = &mockPacketConn{addr: &net.UDPAddr{}} conn = newMockPacketConn()
conn.addr = &net.UDPAddr{}
config = &Config{Versions: protocol.SupportedVersions} config = &Config{Versions: protocol.SupportedVersions}
}) })
@ -235,14 +238,14 @@ var _ = Describe("Server", func() {
}) })
It("works if no quic.Config is given", func(done Done) { It("works if no quic.Config is given", func(done Done) {
ln, err := ListenAddr("127.0.0.1:0", nil, config) ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(ln.Close()).To(Succeed()) Expect(ln.Close()).To(Succeed())
close(done) close(done)
}, 1) }, 1)
It("closes properly", func() { It("closes properly", func() {
ln, err := ListenAddr("127.0.0.1:0", nil, config) ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
var returned bool var returned bool
@ -409,7 +412,7 @@ var _ = Describe("Server", func() {
} }
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead = b.Bytes() conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr conn.dataReadFrom = udpAddr
ln, err := Listen(conn, nil, config) ln, err := Listen(conn, nil, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -432,7 +435,7 @@ var _ = Describe("Server", func() {
}) })
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() { It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
config.Versions = []protocol.VersionNumber{99} config.Versions = []protocol.VersionNumber{99, protocol.VersionTLS}
b := &bytes.Buffer{} b := &bytes.Buffer{}
hdr := wire.Header{ hdr := wire.Header{
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -441,11 +444,12 @@ var _ = Describe("Server", func() {
PacketNumber: 0x55, PacketNumber: 0x55,
Version: 0x1234, Version: 0x1234,
} }
hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead = b.Bytes() conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr conn.dataReadFrom = udpAddr
ln, err := Listen(conn, nil, config) ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done := make(chan struct{}) done := make(chan struct{})
@ -466,9 +470,32 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
}) })
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
version := protocol.VersionNumber(99)
Expect(version.UsesTLS()).To(BeFalse())
config.Versions = []protocol.VersionNumber{version}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeInitial,
IsLongHeader: true,
ConnectionID: 0x1337,
PacketNumber: 0x55,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
defer ln.Close()
Expect(err).ToNot(HaveOccurred())
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
})
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() { It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
conn.dataReadFrom = udpAddr conn.dataReadFrom = udpAddr
conn.dataToRead = []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01} conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
ln, err := Listen(conn, nil, config) ln, err := Listen(conn, nil, config)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go func() {

179
server_tls.go Normal file
View file

@ -0,0 +1,179 @@
package quic
import (
"crypto/tls"
"fmt"
"net"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type nullAEAD struct {
aead crypto.AEAD
}
var _ quicAEAD = &nullAEAD{}
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
data, err := n.aead.Open(dst, src, packetNumber, associatedData)
return data, protocol.EncryptionUnencrypted, err
}
type serverTLS struct {
conn net.PacketConn
config *Config
supportedVersions []protocol.VersionNumber
mintConf *mint.Config
cookieProtector mint.CookieProtector
params *handshake.TransportParameters
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
sessionChan chan<- packetHandler
}
func newServerTLS(
conn net.PacketConn,
config *Config,
cookieHandler *handshake.CookieHandler,
tlsConf *tls.Config,
) (*serverTLS, <-chan packetHandler, error) {
mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer)
if err != nil {
return nil, nil, err
}
mconf.RequireCookie = true
cs, err := mint.NewDefaultCookieProtector()
if err != nil {
return nil, nil, err
}
mconf.CookieProtector = cs
mconf.CookieHandler = cookieHandler
sessionChan := make(chan packetHandler)
s := &serverTLS{
conn: conn,
config: config,
supportedVersions: config.Versions,
mintConf: mconf,
sessionChan: sessionChan,
params: &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: config.IdleTimeout,
},
}
s.newMintConn = s.newMintConnImpl
return s, sessionChan, nil
}
func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
utils.Debugf("Received a Packet. Handling it statelessly.")
sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
if err != nil {
utils.Errorf("Error occured handling initial packet: %s", err)
return
}
if sess == nil { // a stateless reset was done
return
}
s.sessionChan <- sess
}
// will be set to s.newMintConn by the constructor
func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
conn := mint.Server(bc, s.mintConf)
extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v)
if err := conn.SetExtensionHandler(extHandler); err != nil {
return nil, nil, err
}
tls := newMintController(bc, s.mintConf, protocol.PerspectiveServer)
tls.SetExtensionHandler(extHandler)
return tls, extHandler.GetPeerParams(), nil
}
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
// TODO: check length requirement
// check version, if not matching send VNP
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
_, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.supportedVersions), remoteAddr)
return nil, err
}
// unpack packet and check stream frame contents
version := hdr.Version
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version)
if err != nil {
return nil, err
}
frame, err := unpackInitialPacket(aead, hdr, data, version)
if err != nil {
utils.Debugf("Error unpacking initial packet: %s", err)
return nil, nil
}
bc := handshake.NewCryptoStreamConn(remoteAddr)
bc.AddDataForReading(frame.Data)
tls, paramsChan, err := s.newMintConn(bc, hdr.Version)
if err != nil {
return nil, err
}
alert := tls.Handshake()
if alert == mint.AlertStatelessRetry {
// the HelloRetryRequest was written to the bufferConn
// Take that data and write send a Retry packet
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
ConnectionID: hdr.ConnectionID, // echo the client's connection ID
PacketNumber: hdr.PacketNumber, // echo the client's packet number
Version: version,
}
f := &wire.StreamFrame{
StreamID: version.CryptoStreamID(),
Data: bc.GetDataForWriting(),
}
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer)
if err != nil {
return nil, err
}
_, err = s.conn.WriteTo(data, remoteAddr)
return nil, err
}
if alert != mint.AlertNoAlert {
return nil, alert
}
if tls.State() != mint.StateServerNegotiated {
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
}
if alert := tls.Handshake(); alert != mint.AlertNoAlert {
return nil, alert
}
if tls.State() != mint.StateServerWaitFlight2 {
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
}
params := <-paramsChan
sess, err := newTLSServerSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here
protocol.PacketNumber(1), // TODO: use a random packet number here
s.config,
tls,
bc,
aead,
&params,
version,
)
if err != nil {
return nil, err
}
cs := sess.getCryptoStream()
cs.SetReadOffset(frame.DataLen())
bc.SetStream(cs)
return sess, nil
}

116
server_tls_test.go Normal file
View file

@ -0,0 +1,116 @@
package quic
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Stateless TLS handling", func() {
var (
conn *mockPacketConn
server *serverTLS
sessionChan <-chan packetHandler
mintTLS *mockhandshake.MockMintTLS
extHandler *mocks.MockTLSExtensionHandler
mintReply io.Writer
)
BeforeEach(func() {
mintTLS = mockhandshake.NewMockMintTLS(mockCtrl)
extHandler = mocks.NewMockTLSExtensionHandler(mockCtrl)
conn = newMockPacketConn()
config := &Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
var err error
server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
mintReply = bc
return mintTLS, extHandler.GetPeerParams(), nil
}
})
getPacket := func(f wire.Frame) (*wire.Header, []byte) {
hdrBuf := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
PacketNumber: 1,
Version: protocol.VersionTLS,
}
err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
hdr.Raw = hdrBuf.Bytes()
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, 0, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
buf := &bytes.Buffer{}
err = f.Write(buf, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
return hdr, aead.Seal(nil, buf.Bytes(), 1, hdr.Raw)
}
It("sends a version negotiation packet if it doesn't support the version", func() {
server.HandleInitial(nil, &wire.Header{Version: 0x1337}, nil)
Expect(conn.dataWritten.Len()).ToNot(BeZero())
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionUnknown)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsVersionNegotiation).To(BeTrue())
Expect(sessionChan).ToNot(Receive())
})
It("ignores packets with invalid contents", func() {
hdr, data := getPacket(&wire.StreamFrame{StreamID: 10, Offset: 11, Data: []byte("foobar")})
server.HandleInitial(nil, hdr, data)
Expect(conn.dataWritten.Len()).To(BeZero())
Expect(sessionChan).ToNot(Receive())
})
It("replies with a Retry packet, if a Cookie is required", func() {
extHandler.EXPECT().GetPeerParams()
mintTLS.EXPECT().Handshake().Return(mint.AlertStatelessRetry).Do(func() {
mintReply.Write([]byte("Retry with this Cookie"))
})
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
server.HandleInitial(nil, hdr, data)
Expect(conn.dataWritten.Len()).ToNot(BeZero())
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(sessionChan).ToNot(Receive())
})
It("replies with a Handshake packet and creates a session, if no Cookie is required", func() {
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert).Do(func() {
mintReply.Write([]byte("Server Hello"))
})
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert)
mintTLS.EXPECT().State().Return(mint.StateServerNegotiated)
mintTLS.EXPECT().State().Return(mint.StateServerWaitFlight2)
paramsChan := make(chan handshake.TransportParameters, 1)
paramsChan <- handshake.TransportParameters{}
extHandler.EXPECT().GetPeerParams().Return(paramsChan)
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.HandleInitial(nil, hdr, data)
// the Handshake packet is written by the session
Expect(conn.dataWritten.Len()).To(BeZero())
close(done)
}()
Eventually(sessionChan).Should(Receive())
Eventually(done).Should(BeClosed())
})
})

View file

@ -11,6 +11,7 @@ import (
"github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/protocol"
@ -60,7 +61,7 @@ type session struct {
conn connection conn connection
streamsMap *streamsMap streamsMap *streamsMap
cryptoStream streamI cryptoStream cryptoStream
rttStats *congestion.RTTStats rttStats *congestion.RTTStats
@ -126,21 +127,48 @@ func newSession(
conn connection, conn connection,
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
sCfg *handshake.ServerConfig, scfg *handshake.ServerConfig,
tlsConf *tls.Config, tlsConf *tls.Config,
config *Config, config *Config,
) (packetHandler, error) { ) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
config: config, config: config,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
} }
return s, s.setup(sCfg, "", tlsConf, v, nil) s.preSetup()
transportParams := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: s.config.IdleTimeout,
}
cs, err := newCryptoSetup(
s.cryptoStream,
s.connectionID,
s.conn.RemoteAddr(),
s.version,
scfg,
transportParams,
s.config.Versions,
s.config.AcceptCookie,
paramsChan,
aeadChanged,
)
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(1)
} }
// declare this as a variable, such that we can it mock it in the tests // declare this as a variable, so that we can it mock it in the tests
var newClientSession = func( var newClientSession = func(
conn connection, conn connection,
hostname string, hostname string,
@ -151,27 +179,130 @@ var newClientSession = func(
initialVersion protocol.VersionNumber, initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
) (packetHandler, error) { ) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
config: config, config: config,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
} }
return s, s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions) s.preSetup()
transportParams := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: s.config.IdleTimeout,
OmitConnectionID: s.config.RequestConnectionIDOmission,
}
cs, err := newCryptoSetupClient(
s.cryptoStream,
hostname,
s.connectionID,
s.version,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
initialVersion,
negotiatedVersions,
)
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(1)
} }
func (s *session) setup( func newTLSServerSession(
scfg *handshake.ServerConfig, conn connection,
hostname string, connectionID protocol.ConnectionID,
tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber,
initialVersion protocol.VersionNumber, config *Config,
negotiatedVersions []protocol.VersionNumber, tls handshake.MintTLS,
) error { cryptoStreamConn *handshake.CryptoStreamConn,
nullAEAD crypto.AEAD,
peerParams *handshake.TransportParameters,
v protocol.VersionNumber,
) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2) aeadChanged := make(chan protocol.EncryptionLevel, 2)
paramsChan := make(chan handshake.TransportParameters) s := &session{
s.aeadChanged = aeadChanged conn: conn,
s.paramsChan = paramsChan config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
aeadChanged: aeadChanged,
}
s.preSetup()
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
tls,
cryptoStreamConn,
nullAEAD,
aeadChanged,
v,
)
if err := s.postSetup(initialPacketNumber); err != nil {
return nil, err
}
s.peerParams = peerParams
s.processTransportParameters(peerParams)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, nil
}
// declare this as a variable, such that we can it mock it in the tests
var newTLSClientSession = func(
conn connection,
hostname string,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
config *Config,
tls handshake.MintTLS,
paramsChan <-chan handshake.TransportParameters,
initialPacketNumber protocol.PacketNumber,
) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s := &session{
conn: conn,
config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
}
s.preSetup()
tls.SetCryptoStream(s.cryptoStream)
cs, err := handshake.NewCryptoSetupTLSClient(
s.cryptoStream,
s.connectionID,
hostname,
aeadChanged,
tls,
v,
)
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(initialPacketNumber)
}
func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{}
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream)
}
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.handshakeChan = make(chan handshakeEvent, 3) s.handshakeChan = make(chan handshakeEvent, 3)
s.handshakeCompleteChan = make(chan error, 1) s.handshakeCompleteChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
@ -185,91 +316,14 @@ func (s *session) setup(
s.lastNetworkActivityTime = now s.lastNetworkActivityTime = now
s.sessionCreationTime = now s.sessionCreationTime = now
s.rttStats = &congestion.RTTStats{}
transportParams := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: s.config.IdleTimeout,
}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats) s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version) s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version) s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
s.cryptoStream = s.newStream(s.version.CryptoStreamID())
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController) s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
var err error
if s.perspective == protocol.PerspectiveServer {
verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool {
return s.config.AcceptCookie(clientAddr, cookie)
}
if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
s.cryptoStream,
s.connectionID,
tlsConf,
s.conn.RemoteAddr(),
transportParams,
paramsChan,
aeadChanged,
verifySourceAddr,
s.config.Versions,
s.version,
)
} else {
s.cryptoSetup, err = newCryptoSetup(
s.cryptoStream,
s.connectionID,
s.conn.RemoteAddr(),
s.version,
scfg,
transportParams,
s.config.Versions,
verifySourceAddr,
paramsChan,
aeadChanged,
)
}
} else {
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
s.cryptoStream,
s.connectionID,
hostname,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
initialVersion,
s.config.Versions,
s.version,
)
} else {
s.cryptoSetup, err = newCryptoSetupClient(
s.cryptoStream,
hostname,
s.connectionID,
s.version,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
initialVersion,
negotiatedVersions,
)
}
}
if err != nil {
return err
}
s.packer = newPacketPacker(s.connectionID, s.packer = newPacketPacker(s.connectionID,
initialPacketNumber,
s.cryptoSetup, s.cryptoSetup,
s.streamFramer, s.streamFramer,
s.perspective, s.perspective,
@ -604,7 +658,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
s.cryptoStream.Cancel(quicErr) s.cryptoStream.Cancel(quicErr)
s.streamsMap.CloseWithError(quicErr) s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion { if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
return nil return nil
} }
@ -893,6 +947,10 @@ func (s *session) handshakeStatus() <-chan handshakeEvent {
return s.handshakeChan return s.handshakeChan
} }
func (s *session) getCryptoStream() cryptoStream {
return s.cryptoStream
}
func (s *session) GetVersion() protocol.VersionNumber { func (s *session) GetVersion() protocol.VersionNumber {
return s.version return s.version
} }

View file

@ -468,9 +468,6 @@ var _ = Describe("Session", func() {
}) })
It("handles CONNECTION_CLOSE frames", func() { It("handles CONNECTION_CLOSE frames", func() {
cryptoStream := mocks.NewMockStreamI(mockCtrl)
cryptoStream.EXPECT().Cancel(gomock.Any())
sess.cryptoStream = cryptoStream
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
@ -771,10 +768,15 @@ var _ = Describe("Session", func() {
}) })
Context("sending packets", func() { Context("sending packets", func() {
BeforeEach(func() {
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
})
It("sends ACK frames", func() { It("sends ACK frames", func() {
packetNumber := protocol.PacketNumber(0x035e) packetNumber := protocol.PacketNumber(0x035e)
sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
err := sess.sendPacket() Expect(err).ToNot(HaveOccurred())
err = sess.sendPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e}))))
@ -858,6 +860,7 @@ var _ = Describe("Session", func() {
BeforeEach(func() { BeforeEach(func() {
// a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet // a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet
sess.packer.packetNumberGenerator.next = 0x1337 + 10 sess.packer.packetNumberGenerator.next = 0x1337 + 10
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
sph = newMockSentPacketHandler().(*mockSentPacketHandler) sph = newMockSentPacketHandler().(*mockSentPacketHandler)
sess.sentPacketHandler = sph sess.sentPacketHandler = sph
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
@ -981,6 +984,7 @@ var _ = Describe("Session", func() {
}) })
It("retransmits RTO packets", func() { It("retransmits RTO packets", func() {
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
sess.sentPacketHandler.SetHandshakeComplete() sess.sentPacketHandler.SetHandshakeComplete()
n := protocol.PacketNumber(10) n := protocol.PacketNumber(10)
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
@ -1008,6 +1012,7 @@ var _ = Describe("Session", func() {
Context("scheduling sending", func() { Context("scheduling sending", func() {
BeforeEach(func() { BeforeEach(func() {
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
sess.processTransportParameters(&handshake.TransportParameters{ sess.processTransportParameters(&handshake.TransportParameters{
StreamFlowControlWindow: protocol.MaxByteCount, StreamFlowControlWindow: protocol.MaxByteCount,
ConnectionFlowControlWindow: protocol.MaxByteCount, ConnectionFlowControlWindow: protocol.MaxByteCount,
@ -1291,6 +1296,7 @@ var _ = Describe("Session", func() {
sess.handshakeComplete = true sess.handshakeComplete = true
sess.config.KeepAlive = true sess.config.KeepAlive = true
sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2)
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
go sess.run() go sess.run()
defer sess.Close(nil) defer sess.Close(nil)
var data []byte var data []byte
@ -1551,7 +1557,10 @@ var _ = Describe("Client Session", func() {
}) })
It("passes the diversification nonce to the cryptoSetup", func() { It("passes the diversification nonce to the cryptoSetup", func() {
go sess.run() go func() {
defer GinkgoRecover()
sess.run()
}()
hdr.PacketNumber = 5 hdr.PacketNumber = 5
hdr.DiversificationNonce = []byte("foobar") hdr.DiversificationNonce = []byte("foobar")
err := sess.handlePacketImpl(&receivedPacket{header: hdr}) err := sess.handlePacketImpl(&receivedPacket{header: hdr})

View file

@ -32,6 +32,11 @@ type streamI interface {
IsFlowControlBlocked() bool IsFlowControlBlocked() bool
} }
type cryptoStream interface {
streamI
SetReadOffset(protocol.ByteCount)
}
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
// //
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. // Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
@ -481,3 +486,11 @@ func (s *stream) IsFlowControlBlocked() bool {
func (s *stream) GetWindowUpdate() protocol.ByteCount { func (s *stream) GetWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate() return s.flowController.GetWindowUpdate()
} }
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *stream) SetReadOffset(offset protocol.ByteCount) {
s.readOffset = offset
s.frameQueue.readPosition = offset
}

View file

@ -266,6 +266,12 @@ var _ = Describe("Stream", func() {
Expect(onDataCalled).To(BeTrue()) Expect(onDataCalled).To(BeTrue())
}) })
It("sets the read offset", func() {
str.SetReadOffset(0x42)
Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42)))
Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
})
Context("deadlines", func() { Context("deadlines", func() {
It("the deadline error has the right net.Error properties", func() { It("the deadline error has the right net.Error properties", func() {
Expect(errDeadline.Temporary()).To(BeTrue()) Expect(errDeadline.Temporary()).To(BeTrue())

View file

@ -325,4 +325,5 @@ func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.maxOutgoingStreams = limit m.maxOutgoingStreams = limit
m.openStreamOrErrCond.Broadcast()
} }