implement stateless handling of Initial packets for the TLS server

This commit is contained in:
Marten Seemann 2017-11-11 08:38:27 +08:00
parent 57c6f3ceb5
commit 25a6dc9654
36 changed files with 1617 additions and 724 deletions

150
client.go
View file

@ -10,6 +10,7 @@ import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
@ -23,13 +24,17 @@ type client struct {
hostname string
versionNegotiationChan chan struct{} // the versionNegotiationChan is closed as soon as the server accepted the suggested version
versionNegotiated bool // has version negotiation completed yet
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
config *Config
tls handshake.MintTLS // only used when using TLS
connectionID protocol.ConnectionID
initialVersion protocol.VersionNumber
version protocol.VersionNumber
session packetHandler
@ -91,7 +96,6 @@ func DialNonFWSecure(
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
@ -111,8 +115,9 @@ func DialNonFWSecure(
}
utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %s", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
go c.listen()
if err := c.establishSecureConnection(); err != nil {
if err := c.dial(); err != nil {
return nil, err
}
return c.session.(NonFWSession), nil
@ -177,25 +182,79 @@ func populateClientConfig(config *Config) *Config {
}
}
// establishSecureConnection returns as soon as the connection is secure (as opposed to forward-secure)
func (c *client) establishSecureConnection() error {
if err := c.createNewSession(c.version, nil); err != nil {
func (c *client) dial() error {
var err error
if c.version.UsesTLS() {
err = c.dialTLS()
} else {
err = c.dialGQUIC()
}
if err == errCloseSessionForNewVersion {
return c.dial()
}
return err
}
go c.listen()
func (c *client) dialGQUIC() error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
return c.establishSecureConnection()
}
func (c *client) dialTLS() error {
csc := handshake.NewCryptoStreamConn(nil)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ServerName = c.hostname
c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
}
eh := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version)
if err := c.tls.SetExtensionHandler(eh); err != nil {
return err
}
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
if err != handshake.ErrCloseSessionForRetry {
return err
}
utils.Infof("Received a Retry packet. Recreating session.")
if err := c.createNewTLSSession(eh.GetPeerParams(), c.version); err != nil {
return err
}
if err := c.establishSecureConnection(); err != nil {
return err
}
}
return nil
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error {
var runErr error
errorChan := make(chan struct{})
go func() {
// session.run() returns as soon as the session is closed
runErr = c.session.run()
if runErr == errCloseSessionForNewVersion {
// run the new session
runErr = c.session.run()
}
runErr = c.session.run() // returns as soon as the session is closed
close(errorChan)
utils.Infof("Connection %x closed.", c.connectionID)
if runErr != handshake.ErrCloseSessionForRetry && runErr != errCloseSessionForNewVersion {
c.conn.Close()
}
}()
// wait until the server accepts the QUIC version (or an error occurs)
@ -219,7 +278,8 @@ func (c *client) establishSecureConnection() error {
}
}
// Listen listens
// Listen listens on the underlying connection and passes packets on for handling.
// It returns when the connection is closed.
func (c *client) listen() {
var err error
@ -233,13 +293,15 @@ func (c *client) listen() {
n, addr, err = c.conn.Read(data)
if err != nil {
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
c.mutex.Lock()
if c.session != nil {
c.session.Close(err)
}
c.mutex.Unlock()
}
break
}
data = data[:n]
c.handlePacket(addr, data)
c.handlePacket(addr, data[:n])
}
}
@ -257,15 +319,16 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return
}
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
hdr.Raw = packet[:len(packet)-r.Len()]
c.mutex.Lock()
defer c.mutex.Unlock()
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && hdr.ConnectionID != c.connectionID {
return
}
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
@ -305,6 +368,8 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
close(c.versionNegotiationChan)
}
// TODO: validate packet number and connection ID on Retry packets (for IETF QUIC)
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
@ -323,15 +388,15 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
}
}
c.receivedVersionNegotiationPacket = true
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version
initialVersion := c.version
c.initialVersion = c.version
c.version = newVersion
var err error
c.connectionID, err = utils.GenerateConnectionID()
@ -339,17 +404,13 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
return err
}
utils.Infof("Switching to QUIC version %s. New connection ID: %x", newVersion, c.connectionID)
// create a new session and close the old one
// the new session must be created first to update client member variables
oldSession := c.session
defer oldSession.Close(errCloseSessionForNewVersion)
return c.createNewSession(initialVersion, hdr.SupportedVersions)
c.session.Close(errCloseSessionForNewVersion)
return nil
}
func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber) error {
var err error
utils.Debugf("createNewSession with initial version %s", initialVersion)
func (c *client) createNewGQUICSession() (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newClientSession(
c.conn,
c.hostname,
@ -357,8 +418,27 @@ func (c *client) createNewSession(initialVersion protocol.VersionNumber, negotia
c.connectionID,
c.tlsConf,
c.config,
initialVersion,
negotiatedVersions,
c.initialVersion,
c.negotiatedVersions,
)
return err
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.session, err = newTLSClientSession(
c.conn,
c.hostname,
c.version,
c.connectionID,
c.config,
c.tls,
paramsChan,
1,
)
return err
}

View file

@ -8,6 +8,7 @@ import (
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
@ -45,10 +46,9 @@ var _ = Describe("Client", func() {
msess, _ := newMockSession(nil, 0, 0, nil, nil, nil)
sess = msess.(*mockSession)
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
packetConn = &mockPacketConn{
addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234},
dataReadFrom: addr,
}
packetConn = newMockPacketConn()
packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
packetConn.dataReadFrom = addr
config = &Config{
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
}
@ -87,7 +87,7 @@ var _ = Describe("Client", func() {
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
) (packetHandler, error) {
Expect(conn.Write([]byte("fake CHLO"))).To(Succeed())
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
return sess, nil
}
origGenerateConnectionID = generateConnectionID
@ -101,7 +101,7 @@ var _ = Describe("Client", func() {
})
It("dials non-forward-secure", func() {
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -151,7 +151,7 @@ var _ = Describe("Client", func() {
})
It("Dial only returns after the handshake is complete", func() {
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -237,7 +237,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the connection to become secure", func() {
testErr := errors.New("early handshake error")
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -251,7 +251,7 @@ var _ = Describe("Client", func() {
It("returns an error that occurs while waiting for the handshake to complete", func() {
testErr := errors.New("late handshake error")
packetConn.dataToRead = acceptClientVersionPacket(cl.connectionID)
packetConn.dataToRead <- acceptClientVersionPacket(cl.connectionID)
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -330,10 +330,6 @@ var _ = Describe("Client", func() {
newVersion := protocol.VersionNumber(77)
Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion))
packetConn.dataToRead = wire.ComposeGQUICVersionNegotiation(
cl.connectionID,
[]protocol.VersionNumber{newVersion},
)
sessionChan := make(chan *mockSession)
handshakeChan := make(chan handshakeEvent)
newClientSession = func(
@ -348,10 +344,7 @@ var _ = Describe("Client", func() {
) (packetHandler, error) {
initialVersion = initialVersionP
negotiatedVersions = negotiatedVersionsP
// make the server accept the new version
if len(negotiatedVersionsP) > 0 {
packetConn.dataToRead = acceptClientVersionPacket(connectionID)
}
sess := &mockSession{
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
@ -364,18 +357,26 @@ var _ = Describe("Client", func() {
established := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cl.establishSecureConnection()
err := cl.dial()
Expect(err).ToNot(HaveOccurred())
close(established)
}()
go cl.listen()
actualInitialVersion := cl.version
var firstSession, secondSession *mockSession
Eventually(sessionChan).Should(Receive(&firstSession))
Eventually(sessionChan).Should(Receive(&secondSession))
packetConn.dataToRead <- wire.ComposeGQUICVersionNegotiation(
cl.connectionID,
[]protocol.VersionNumber{newVersion},
)
// it didn't pass the version negoation packet to the old session (since it has no payload)
Expect(firstSession.packetCount).To(BeZero())
Eventually(func() bool { return firstSession.closed }).Should(BeTrue())
Expect(firstSession.closeReason).To(Equal(errCloseSessionForNewVersion))
Expect(firstSession.packetCount).To(BeZero())
Eventually(sessionChan).Should(Receive(&secondSession))
// make the server accept the new version
packetConn.dataToRead <- acceptClientVersionPacket(secondSession.connectionID)
Consistently(func() bool { return secondSession.closed }).Should(BeFalse())
Expect(cl.connectionID).ToNot(BeEquivalentTo(0x1337))
Expect(negotiatedVersions).To(ContainElement(newVersion))
@ -398,20 +399,23 @@ var _ = Describe("Client", func() {
_ []protocol.VersionNumber,
) (packetHandler, error) {
atomic.AddUint32(&sessionCounter, 1)
return sess, nil
return &mockSession{
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
}, nil
}
go cl.establishSecureConnection()
go cl.dial()
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(1))
newVersion := protocol.VersionNumber(77)
Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion))
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2))
Eventually(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
newVersion = protocol.VersionNumber(78)
Expect(newVersion).ToNot(Equal(cl.version))
Expect(config.Versions).To(ContainElement(newVersion))
cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2))
Consistently(func() uint32 { return atomic.LoadUint32(&sessionCounter) }).Should(BeEquivalentTo(2))
})
It("errors if no matching version is found", func() {
@ -482,7 +486,7 @@ var _ = Describe("Client", func() {
Expect(sess.closed).To(BeFalse())
})
It("creates new sessions with the right parameters", func() {
It("creates new GQUIC sessions with the right parameters", func() {
closeErr := errors.New("peer doesn't reply")
c := make(chan struct{})
var cconn connection
@ -516,12 +520,84 @@ var _ = Describe("Client", func() {
Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(hostname).To(Equal("quic.clemente.io"))
Expect(version).To(Equal(cl.version))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
sess.Close(closeErr)
Eventually(dialed).Should(BeClosed())
})
It("creates new TLS sessions with the right parameters", func() {
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
c := make(chan struct{})
var cconn connection
var hostname string
var version protocol.VersionNumber
var conf *Config
newTLSClientSession = func(
connP connection,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
configP *Config,
tls handshake.MintTLS,
paramsChan <-chan handshake.TransportParameters,
_ protocol.PacketNumber,
) (packetHandler, error) {
cconn = connP
hostname = hostnameP
version = versionP
conf = configP
close(c)
return sess, nil
}
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
close(dialed)
}()
Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(hostname).To(Equal("quic.clemente.io"))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
sess.Close(errors.New("peer doesn't reply"))
Eventually(dialed).Should(BeClosed())
})
It("creates a new session when the server performs a retry", func() {
config.Versions = []protocol.VersionNumber{protocol.VersionTLS}
sessionChan := make(chan *mockSession)
newTLSClientSession = func(
connP connection,
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
configP *Config,
tls handshake.MintTLS,
paramsChan <-chan handshake.TransportParameters,
_ protocol.PacketNumber,
) (packetHandler, error) {
sess := &mockSession{
stopRunLoop: make(chan struct{}),
}
sessionChan <- sess
return sess, nil
}
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
close(dialed)
}()
var firstSession, secondSession *mockSession
Eventually(sessionChan).Should(Receive(&firstSession))
firstSession.Close(handshake.ErrCloseSessionForRetry)
Eventually(sessionChan).Should(Receive(&secondSession))
secondSession.Close(errors.New("stop test"))
Eventually(dialed).Should(BeClosed())
})
Context("handling packets", func() {
It("handles packets", func() {
ph := wire.Header{
@ -532,7 +608,7 @@ var _ = Describe("Client", func() {
b := &bytes.Buffer{}
err := ph.Write(b, protocol.PerspectiveServer, cl.version)
Expect(err).ToNot(HaveOccurred())
packetConn.dataToRead = b.Bytes()
packetConn.dataToRead <- b.Bytes()
Expect(sess.packetCount).To(BeZero())
stoppedListening := make(chan struct{})

View file

@ -2,7 +2,7 @@ package quic
import (
"bytes"
"io"
"errors"
"net"
"time"
@ -12,7 +12,7 @@ import (
type mockPacketConn struct {
addr net.Addr
dataToRead []byte
dataToRead chan []byte
dataReadFrom net.Addr
readErr error
dataWritten bytes.Buffer
@ -20,23 +20,34 @@ type mockPacketConn struct {
closed bool
}
func newMockPacketConn() *mockPacketConn {
return &mockPacketConn{
dataToRead: make(chan []byte, 1000),
}
}
func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
if c.readErr != nil {
return 0, nil, c.readErr
}
if c.dataToRead == nil { // block if there's no data
time.Sleep(time.Hour)
return 0, nil, io.EOF
data, ok := <-c.dataToRead
if !ok {
return 0, nil, errors.New("connection closed")
}
n := copy(b, c.dataToRead)
c.dataToRead = nil
n := copy(b, data)
return n, c.dataReadFrom, nil
}
func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
c.dataWrittenTo = addr
return c.dataWritten.Write(b)
}
func (c *mockPacketConn) Close() error { c.closed = true; return nil }
func (c *mockPacketConn) Close() error {
if !c.closed {
close(c.dataToRead)
}
c.closed = true
return nil
}
func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr }
func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") }
func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") }
@ -53,7 +64,7 @@ var _ = Describe("Connection", func() {
IP: net.IPv4(192, 168, 100, 200),
Port: 1337,
}
packetConn = &mockPacketConn{}
packetConn = newMockPacketConn()
c = &conn{
currentAddr: addr,
pconn: packetConn,
@ -68,7 +79,7 @@ var _ = Describe("Connection", func() {
})
It("reads", func() {
packetConn.dataToRead = []byte("foo")
packetConn.dataToRead <- []byte("foo")
packetConn.dataReadFrom = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1336}
p := make([]byte, 10)
n, raddr, err := c.Read(p)

View file

@ -7,33 +7,33 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils"
)
type cookieHandler struct {
type CookieHandler struct {
callback func(net.Addr, *Cookie) bool
cookieGenerator *CookieGenerator
}
var _ mint.CookieHandler = &cookieHandler{}
var _ mint.CookieHandler = &CookieHandler{}
func newCookieHandler(callback func(net.Addr, *Cookie) bool) (*cookieHandler, error) {
func NewCookieHandler(callback func(net.Addr, *Cookie) bool) (*CookieHandler, error) {
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}
return &cookieHandler{
return &CookieHandler{
callback: callback,
cookieGenerator: cookieGenerator,
}, nil
}
func (h *cookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
if h.callback(conn.RemoteAddr(), nil) {
return nil, nil
}
return h.cookieGenerator.NewToken(conn.RemoteAddr())
}
func (h *cookieHandler) Validate(conn *mint.Conn, token []byte) bool {
func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
data, err := h.cookieGenerator.DecodeToken(token)
if err != nil {
utils.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())

View file

@ -2,6 +2,7 @@ package handshake
import (
"net"
"time"
"github.com/bifurcation/mint"
@ -9,22 +10,37 @@ import (
. "github.com/onsi/gomega"
)
type mockConn struct {
remoteAddr net.Addr
}
var _ net.Conn = &mockConn{}
func (c *mockConn) Read([]byte) (int, error) { panic("not implemented") }
func (c *mockConn) Write([]byte) (int, error) { panic("not implemented") }
func (c *mockConn) Close() error { panic("not implemented") }
func (c *mockConn) LocalAddr() net.Addr { panic("not implemented") }
func (c *mockConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *mockConn) SetReadDeadline(time.Time) error { panic("not implemented") }
func (c *mockConn) SetWriteDeadline(time.Time) error { panic("not implemented") }
func (c *mockConn) SetDeadline(time.Time) error { panic("not implemented") }
var callbackReturn bool
var mockCallback = func(net.Addr, *Cookie) bool {
return callbackReturn
}
var _ = Describe("Cookie Handler", func() {
var ch *cookieHandler
var ch *CookieHandler
var conn *mint.Conn
BeforeEach(func() {
callbackReturn = false
var err error
ch, err = newCookieHandler(mockCallback)
ch, err = NewCookieHandler(mockCallback)
Expect(err).ToNot(HaveOccurred())
addr := &net.UDPAddr{IP: net.IPv4(42, 43, 44, 45), Port: 46}
conn = mint.NewConn(&fakeConn{remoteAddr: addr}, &mint.Config{}, false)
conn = mint.NewConn(&mockConn{remoteAddr: addr}, &mint.Config{}, false)
})
It("generates and validates a token", func() {

View file

@ -381,10 +381,6 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) {
h.divNonceChan <- data
}
func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {

View file

@ -458,10 +458,6 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType {
panic("not needed for cryptoSetupServer")
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")

View file

@ -1,10 +1,9 @@
package handshake
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"github.com/bifurcation/mint"
@ -12,6 +11,9 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
// KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
@ -20,68 +22,31 @@ type cryptoSetupTLS struct {
perspective protocol.Perspective
tls mintTLS
conn *fakeConn
nextPacketType protocol.PacketType
keyDerivation KeyDerivationFunction
nullAEAD crypto.AEAD
aead crypto.AEAD
tls MintTLS
cryptoStream *CryptoStreamConn
aeadChanged chan<- protocol.EncryptionLevel
}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
tlsConfig *tls.Config,
remoteAddr net.Addr,
params *TransportParameters,
paramsChan chan<- TransportParameters,
tls MintTLS,
cryptoStream *CryptoStreamConn,
nullAEAD crypto.AEAD,
aeadChanged chan<- protocol.EncryptionLevel,
checkCookie func(net.Addr, *Cookie) bool,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
) (CryptoSetup, error) {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveServer)
if err != nil {
return nil, err
}
mintConf.RequireCookie = true
mintConf.CookieHandler, err = newCookieHandler(checkCookie)
if err != nil {
return nil, err
}
mintConf.CookieProtector, err = mint.NewDefaultCookieProtector()
if err != nil {
return nil, err
}
conn := &fakeConn{
stream: cryptoStream,
pers: protocol.PerspectiveServer,
remoteAddr: remoteAddr,
}
mintConn := mint.Server(conn, mintConf)
eh := newExtensionHandlerServer(params, paramsChan, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
) CryptoSetup {
return &cryptoSetupTLS{
perspective: protocol.PerspectiveServer,
tls: &mintController{mintConn},
conn: conn,
tls: tls,
cryptoStream: cryptoStream,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
}, nil
}
}
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
@ -89,60 +54,44 @@ func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
hostname string,
tlsConfig *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
aeadChanged chan<- protocol.EncryptionLevel,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
tls MintTLS,
version protocol.VersionNumber,
) (CryptoSetup, error) {
mintConf, err := tlsToMintConfig(tlsConfig, protocol.PerspectiveClient)
if err != nil {
return nil, err
}
mintConf.ServerName = hostname
conn := &fakeConn{
stream: cryptoStream,
pers: protocol.PerspectiveClient,
}
mintConn := mint.Client(conn, mintConf)
eh := newExtensionHandlerClient(params, paramsChan, initialVersion, supportedVersions, version)
if err := mintConn.SetExtensionHandler(eh); err != nil {
return nil, err
}
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupTLS{
conn: conn,
perspective: protocol.PerspectiveClient,
tls: &mintController{mintConn},
tls: tls,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
aeadChanged: aeadChanged,
nextPacketType: protocol.PacketTypeInitial,
}, nil
}
func (h *cryptoSetupTLS) HandleCryptoStream() error {
handshakeLoop:
for {
switch alert := h.tls.Handshake(); alert {
case mint.AlertStatelessRetry:
case mint.AlertNoAlert: // handshake complete
break handshakeLoop
case mint.AlertWouldBlock:
h.determineNextPacketType()
if err := h.conn.Continue(); err != nil {
if h.perspective == protocol.PerspectiveServer {
// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
// send out that data now
if _, err := h.cryptoStream.Flush(); err != nil {
return err
}
default:
}
handshakeLoop:
for {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
switch h.tls.State() {
case mint.StateClientStart: // this happens if a stateless retry is performed
return ErrCloseSessionForRetry
case mint.StateClientConnected, mint.StateServerConnected:
break handshakeLoop
}
}
aead, err := h.keyDerivation(h.tls, h.perspective)
@ -209,35 +158,6 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) determineNextPacketType() error {
h.mutex.Lock()
defer h.mutex.Unlock()
state := h.tls.State().HandshakeState
if h.perspective == protocol.PerspectiveServer {
switch state {
case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest
h.nextPacketType = protocol.PacketTypeRetry
case "ServerStateWaitFinished":
h.nextPacketType = protocol.PacketTypeHandshake
default:
// TODO: accept 0-RTT data
return fmt.Errorf("Unexpected handshake state: %s", state)
}
return nil
}
// client
if state != "ClientStateWaitSH" {
h.nextPacketType = protocol.PacketTypeHandshake
}
return nil
}
func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.nextPacketType
}
func (h *cryptoSetupTLS) DiversificationNonce() []byte {
panic("diversification nonce not needed for TLS")
}

View file

@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"errors"
"fmt"
@ -10,7 +9,6 @@ import (
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -23,52 +21,33 @@ func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, e
var _ = Describe("TLS Crypto Setup", func() {
var (
cs *cryptoSetupTLS
paramsChan chan TransportParameters
aeadChanged chan protocol.EncryptionLevel
)
BeforeEach(func() {
paramsChan = make(chan TransportParameters)
aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupTLSServer(
cs = NewCryptoSetupTLSServer(
nil,
1,
testdata.GetTLSConfig(),
nil,
&TransportParameters{},
paramsChan,
NewCryptoStreamConn(nil),
nil, // AEAD
aeadChanged,
nil,
nil,
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupTLS)
).(*cryptoSetupTLS)
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
})
It("errors when the handshake fails", func() {
alert := mint.AlertBadRecordMAC
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(alert)
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(alert)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)))
})
It("continues shaking hands when mint says that it would block", func() {
cs.conn.stream = &bytes.Buffer{}
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{})
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
})
It("derives keys", func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@ -76,64 +55,22 @@ var _ = Describe("TLS Crypto Setup", func() {
Expect(aeadChanged).To(BeClosed())
})
Context("determining the packet type", func() {
Context("for the client", func() {
var csClient *cryptoSetupTLS
BeforeEach(func() {
csInt, err := NewCryptoSetupTLSClient(
nil,
1,
"quic.clemente.io",
testdata.GetTLSConfig(),
&TransportParameters{},
paramsChan,
aeadChanged,
protocol.VersionTLS,
[]protocol.VersionNumber{protocol.VersionTLS},
protocol.VersionTLS,
)
It("handshakes until it is connected", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerNegotiated).Times(9)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
csClient = csInt.(*cryptoSetupTLS)
csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl)
})
It("sends a Client Initial first", func() {
Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeInitial))
})
It("sends a Handshake packet after the server sent a Server Hello", func() {
csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"})
err := csClient.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
})
})
Context("for the server", func() {
BeforeEach(func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
})
It("sends a Stateless Retry packet", func() {
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"})
err := cs.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeRetry))
})
It("sends Handshake packet", func() {
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"})
err := cs.determineNextPacketType()
Expect(err).ToNot(HaveOccurred())
Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeHandshake))
})
})
Expect(aeadChanged).To(Receive())
})
Context("escalating crypto", func() {
doHandshake := func() {
cs.tls = mockhandshake.NewMockmintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateServerConnected)
cs.keyDerivation = mockKeyDerivation
err := cs.HandleCryptoStream()
Expect(err).ToNot(HaveOccurred())
@ -240,3 +177,33 @@ var _ = Describe("TLS Crypto Setup", func() {
})
})
})
var _ = Describe("TLS Crypto Setup, for the client", func() {
var (
cs *cryptoSetupTLS
aeadChanged chan protocol.EncryptionLevel
)
BeforeEach(func() {
aeadChanged = make(chan protocol.EncryptionLevel, 2)
csInt, err := NewCryptoSetupTLSClient(
nil,
0,
"quic.clemente.io",
aeadChanged,
nil, // mintTLS
protocol.VersionTLS,
)
Expect(err).ToNot(HaveOccurred())
cs = csInt.(*cryptoSetupTLS)
})
It("returns when a retry is performed", func() {
cs.tls = mockhandshake.NewMockMintTLS(mockCtrl)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert)
cs.tls.(*mockhandshake.MockMintTLS).EXPECT().State().Return(mint.StateClientStart)
err := cs.HandleCryptoStream()
Expect(err).To(MatchError(ErrCloseSessionForRetry))
})
})

View file

@ -0,0 +1,101 @@
package handshake
import (
"bytes"
"io"
"net"
"time"
)
// The CryptoStreamConn is used as the net.Conn passed to mint.
// It has two operating modes:
// 1. It can read and write to bytes.Buffers.
// 2. It can use a quic.Stream for reading and writing.
// The buffer-mode is only used by the server, in order to statelessly handle retries.
type CryptoStreamConn struct {
remoteAddr net.Addr
// the buffers are used before the session is initialized
readBuf bytes.Buffer
writeBuf bytes.Buffer
// stream will be set once the session is initialized
stream io.ReadWriter
}
var _ net.Conn = &CryptoStreamConn{}
// NewCryptoStreamConn creates a new CryptoStreamConn
func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
return &CryptoStreamConn{remoteAddr: remoteAddr}
}
func (c *CryptoStreamConn) Read(b []byte) (int, error) {
if c.stream != nil {
return c.stream.Read(b)
}
return c.readBuf.Read(b)
}
// AddDataForReading adds data to the read buffer.
// This data will ONLY be read when the stream has not been set.
func (c *CryptoStreamConn) AddDataForReading(data []byte) {
c.readBuf.Write(data)
}
func (c *CryptoStreamConn) Write(p []byte) (int, error) {
if c.stream != nil {
return c.stream.Write(p)
}
return c.writeBuf.Write(p)
}
// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
func (c *CryptoStreamConn) GetDataForWriting() []byte {
defer c.writeBuf.Reset()
data := make([]byte, c.writeBuf.Len())
copy(data, c.writeBuf.Bytes())
return data
}
// SetStream sets the stream.
// After setting the stream, the read and write buffer won't be used any more.
func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
c.stream = stream
}
// Flush copies the contents of the write buffer to the stream
func (c *CryptoStreamConn) Flush() (int, error) {
n, err := io.Copy(c.stream, &c.writeBuf)
return int(n), err
}
// Close is not implemented
func (c *CryptoStreamConn) Close() error {
return nil
}
// LocalAddr is not implemented
func (c *CryptoStreamConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr returns the remote address
func (c *CryptoStreamConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
// SetReadDeadline is not implemented
func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is not implemented
func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// SetDeadline is not implemented
func (c *CryptoStreamConn) SetDeadline(time.Time) error {
return nil
}

View file

@ -0,0 +1,67 @@
package handshake
import (
"bytes"
"net"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("CryptoStreamConn", func() {
var (
csc *CryptoStreamConn
remoteAddr net.Addr
)
BeforeEach(func() {
remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
csc = NewCryptoStreamConn(remoteAddr)
})
It("reads from the read buffer, when no stream is set", func() {
csc.AddDataForReading([]byte("foobar"))
data := make([]byte, 4)
n, err := csc.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(4))
Expect(data).To(Equal([]byte("foob")))
})
It("writes to the write buffer, when no stream is set", func() {
csc.Write([]byte("foo"))
Expect(csc.GetDataForWriting()).To(Equal([]byte("foo")))
csc.Write([]byte("bar"))
Expect(csc.GetDataForWriting()).To(Equal([]byte("bar")))
})
It("reads from the stream, if available", func() {
csc.stream = &bytes.Buffer{}
csc.stream.Write([]byte("foobar"))
data := make([]byte, 3)
n, err := csc.Read(data)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(Equal(3))
Expect(data).To(Equal([]byte("foo")))
})
It("writes to the stream, if available", func() {
stream := &bytes.Buffer{}
csc.SetStream(stream)
csc.Write([]byte("foobar"))
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
It("returns the remote address", func() {
Expect(csc.RemoteAddr()).To(Equal(remoteAddr))
})
It("has unimplemented methods", func() {
Expect(csc.Close()).ToNot(HaveOccurred())
Expect(csc.SetDeadline(time.Time{})).ToNot(HaveOccurred())
Expect(csc.SetReadDeadline(time.Time{})).ToNot(HaveOccurred())
Expect(csc.SetWriteDeadline(time.Time{})).ToNot(HaveOccurred())
Expect(csc.LocalAddr()).To(BeNil())
})
})

View file

@ -1,6 +1,10 @@
package handshake
import (
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
@ -10,6 +14,26 @@ type Sealer interface {
Overhead() int
}
// A TLSExtensionHandler sends and received the QUIC TLS extension.
// It provides the parameters sent by the peer on a channel.
type TLSExtensionHandler interface {
Send(mint.HandshakeType, *mint.ExtensionList) error
Receive(mint.HandshakeType, *mint.ExtensionList) error
GetPeerParams() <-chan TransportParameters
}
// MintTLS combines some methods needed to interact with mint.
type MintTLS interface {
crypto.TLSExporter
// additional methods
Handshake() mint.Alert
State() mint.State
SetCryptoStream(io.ReadWriter)
SetExtensionHandler(mint.AppExtensionHandler) error
}
// CryptoSetup is a crypto setup
type CryptoSetup interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
@ -17,7 +41,6 @@ type CryptoSetup interface {
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) // only needed for cryptoSetupClient
GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)

View file

@ -1,127 +0,0 @@
package handshake
import (
"bytes"
gocrypto "crypto"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
mconf := &mint.Config{
NonBlocking: true,
CipherSuites: []mint.CipherSuite{
mint.TLS_AES_128_GCM_SHA256,
mint.TLS_AES_256_GCM_SHA384,
},
}
if tlsConf != nil {
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
}
for j, cert := range certChain.Certificate {
c, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
mconf.Certificates[i].Chain[j] = c
}
}
}
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
return nil, err
}
return mconf, nil
}
type mintTLS interface {
// These two methods are the same as the crypto.TLSExporter interface.
// Cannot use embedding here, because mockgen source mode refuses to generate mocks then.
GetCipherSuite() mint.CipherSuiteParams
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
// additional methods
Handshake() mint.Alert
State() mint.ConnectionState
}
var _ crypto.TLSExporter = (mintTLS)(nil)
type mintController struct {
conn *mint.Conn
}
var _ mintTLS = &mintController{}
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
return mc.conn.State().CipherSuite
}
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
return mc.conn.ComputeExporter(label, context, keyLength)
}
func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
func (mc *mintController) State() mint.ConnectionState {
return mc.conn.State()
}
// mint expects a net.Conn, but we're doing the handshake on a stream
// so we wrap a stream such that implements a net.Conn
type fakeConn struct {
stream io.ReadWriter
pers protocol.Perspective
remoteAddr net.Addr
blockRead bool
writeBuffer bytes.Buffer
}
var _ net.Conn = &fakeConn{}
func (c *fakeConn) Read(b []byte) (int, error) {
if c.blockRead { // this causes mint.Conn.Handshake() to return a mint.AlertWouldBlock
return 0, nil
}
c.blockRead = true // block the next Read call
return c.stream.Read(b)
}
func (c *fakeConn) Write(p []byte) (int, error) {
if c.pers == protocol.PerspectiveClient {
return c.stream.Write(p)
}
// Buffer all writes by the server.
// Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data.
return c.writeBuffer.Write(p)
}
func (c *fakeConn) Continue() error {
c.blockRead = false
if c.pers == protocol.PerspectiveClient {
return nil
}
// write all contents of the write buffer to the stream.
_, err := c.stream.Write(c.writeBuffer.Bytes())
c.writeBuffer.Reset()
return err
}
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *fakeConn) SetReadDeadline(time.Time) error { return nil }
func (c *fakeConn) SetWriteDeadline(time.Time) error { return nil }
func (c *fakeConn) SetDeadline(time.Time) error { return nil }

View file

@ -1,72 +0,0 @@
package handshake
import (
"bytes"
"net"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Fake Conn", func() {
var (
c *fakeConn
stream *bytes.Buffer
)
BeforeEach(func() {
stream = &bytes.Buffer{}
c = &fakeConn{stream: stream}
})
Context("Reading", func() {
It("doesn't return any new data after one Read call", func() {
stream.Write([]byte("foobar"))
b := make([]byte, 3)
_, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("foo")))
n, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(n).To(BeZero())
})
It("allows more Read calls after unblocking", func() {
stream.Write([]byte("foobar"))
b := make([]byte, 3)
_, err := c.Read(b)
Expect(err).ToNot(HaveOccurred())
err = c.Continue()
Expect(err).ToNot(HaveOccurred())
_, err = c.Read(b)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(Equal([]byte("bar")))
})
})
Context("Writing", func() {
It("writes directly when acting as a client", func() {
c.pers = protocol.PerspectiveClient
_, err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
It("only writes after flushing when acting as a server", func() {
c.pers = protocol.PerspectiveServer
_, err := c.Write([]byte("foobar"))
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(BeEmpty())
err = c.Continue()
Expect(err).ToNot(HaveOccurred())
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
})
})
It("returns its remote address", func() {
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
c.remoteAddr = addr
Expect(c.RemoteAddr()).To(Equal(addr))
})
})

View file

@ -13,8 +13,8 @@ import (
)
type extensionHandlerClient struct {
params *TransportParameters
paramsChan chan<- TransportParameters
ourParams *TransportParameters
paramsChan chan TransportParameters
initialVersion protocol.VersionNumber
supportedVersions []protocol.VersionNumber
@ -22,16 +22,17 @@ type extensionHandlerClient struct {
}
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{}
func newExtensionHandlerClient(
func NewExtensionHandlerClient(
params *TransportParameters,
paramsChan chan<- TransportParameters,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
) *extensionHandlerClient {
) TLSExtensionHandler {
paramsChan := make(chan TransportParameters, 1)
return &extensionHandlerClient{
params: params,
ourParams: params,
paramsChan: paramsChan,
initialVersion: initialVersion,
supportedVersions: supportedVersions,
@ -46,7 +47,7 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
data, err := syntax.Marshal(clientHelloTransportParameters{
InitialVersion: uint32(h.initialVersion),
Parameters: h.params.getTransportParameters(),
Parameters: h.ourParams.getTransportParameters(),
})
if err != nil {
return err
@ -123,3 +124,7 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
h.paramsChan <- *params
return nil
}
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View file

@ -15,13 +15,10 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
var (
handler *extensionHandlerClient
el mint.ExtensionList
paramsChan chan TransportParameters
)
BeforeEach(func() {
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
paramsChan = make(chan TransportParameters, 1)
handler = newExtensionHandlerClient(&TransportParameters{}, paramsChan, protocol.VersionWhatever, nil, protocol.VersionWhatever)
handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever).(*extensionHandlerClient)
el = make(mint.ExtensionList, 0)
})
@ -81,7 +78,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el)
Expect(err).ToNot(HaveOccurred())
var params TransportParameters
Expect(paramsChan).To(Receive(&params))
Expect(handler.GetPeerParams()).To(Receive(&params))
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
})

View file

@ -14,26 +14,27 @@ import (
)
type extensionHandlerServer struct {
params *TransportParameters
paramsChan chan<- TransportParameters
ourParams *TransportParameters
paramsChan chan TransportParameters
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
}
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{}
func newExtensionHandlerServer(
func NewExtensionHandlerServer(
params *TransportParameters,
paramsChan chan<- TransportParameters,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
) *extensionHandlerServer {
) TLSExtensionHandler {
paramsChan := make(chan TransportParameters, 1)
return &extensionHandlerServer{
params: params,
ourParams: params,
paramsChan: paramsChan,
version: version,
supportedVersions: supportedVersions,
version: version,
}
}
@ -43,7 +44,7 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
}
transportParams := append(
h.params.getTransportParameters(),
h.ourParams.getTransportParameters(),
// TODO(#855): generate a real token
transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
)
@ -105,3 +106,7 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
h.paramsChan <- *params
return nil
}
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View file

@ -22,13 +22,10 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
var (
handler *extensionHandlerServer
el mint.ExtensionList
paramsChan chan TransportParameters
)
BeforeEach(func() {
// use a buffered channel here, so that we don't have to receive concurrently when parsing a message
paramsChan = make(chan TransportParameters, 1)
handler = newExtensionHandlerServer(&TransportParameters{}, paramsChan, nil, protocol.VersionWhatever)
handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever).(*extensionHandlerServer)
el = make(mint.ExtensionList, 0)
})
@ -91,7 +88,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() {
err := handler.Receive(mint.HandshakeTypeClientHello, &el)
Expect(err).ToNot(HaveOccurred())
var params TransportParameters
Expect(paramsChan).To(Receive(&params))
Expect(handler.GetPeerParams()).To(Receive(&params))
Expect(params.StreamFlowControlWindow).To(BeEquivalentTo(0x11223344))
})

View file

@ -1,6 +1,7 @@
package mocks
//go:generate sh -c "mockgen -source=../handshake/mint_utils.go -package mockhandshake -destination handshake/mint_tls.go"
//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS"
//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"

View file

@ -1,83 +1,106 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../handshake/mint_utils.go
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS)
package mockhandshake
import (
io "io"
reflect "reflect"
mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock"
)
// MockmintTLS is a mock of mintTLS interface
type MockmintTLS struct {
// MockMintTLS is a mock of MintTLS interface
type MockMintTLS struct {
ctrl *gomock.Controller
recorder *MockmintTLSMockRecorder
recorder *MockMintTLSMockRecorder
}
// MockmintTLSMockRecorder is the mock recorder for MockmintTLS
type MockmintTLSMockRecorder struct {
mock *MockmintTLS
// MockMintTLSMockRecorder is the mock recorder for MockMintTLS
type MockMintTLSMockRecorder struct {
mock *MockMintTLS
}
// NewMockmintTLS creates a new mock instance
func NewMockmintTLS(ctrl *gomock.Controller) *MockmintTLS {
mock := &MockmintTLS{ctrl: ctrl}
mock.recorder = &MockmintTLSMockRecorder{mock}
// NewMockMintTLS creates a new mock instance
func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS {
mock := &MockMintTLS{ctrl: ctrl}
mock.recorder = &MockMintTLSMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (_m *MockmintTLS) EXPECT() *MockmintTLSMockRecorder {
func (_m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder {
return _m.recorder
}
// GetCipherSuite mocks base method
func (_m *MockmintTLS) GetCipherSuite() mint.CipherSuiteParams {
ret := _m.ctrl.Call(_m, "GetCipherSuite")
ret0, _ := ret[0].(mint.CipherSuiteParams)
return ret0
}
// GetCipherSuite indicates an expected call of GetCipherSuite
func (_mr *MockmintTLSMockRecorder) GetCipherSuite() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockmintTLS)(nil).GetCipherSuite))
}
// ComputeExporter mocks base method
func (_m *MockmintTLS) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
ret := _m.ctrl.Call(_m, "ComputeExporter", label, context, keyLength)
func (_m *MockMintTLS) ComputeExporter(_param0 string, _param1 []byte, _param2 int) ([]byte, error) {
ret := _m.ctrl.Call(_m, "ComputeExporter", _param0, _param1, _param2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ComputeExporter indicates an expected call of ComputeExporter
func (_mr *MockmintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockmintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
func (_mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2)
}
// GetCipherSuite mocks base method
func (_m *MockMintTLS) GetCipherSuite() mint.CipherSuiteParams {
ret := _m.ctrl.Call(_m, "GetCipherSuite")
ret0, _ := ret[0].(mint.CipherSuiteParams)
return ret0
}
// GetCipherSuite indicates an expected call of GetCipherSuite
func (_mr *MockMintTLSMockRecorder) GetCipherSuite() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetCipherSuite", reflect.TypeOf((*MockMintTLS)(nil).GetCipherSuite))
}
// Handshake mocks base method
func (_m *MockmintTLS) Handshake() mint.Alert {
func (_m *MockMintTLS) Handshake() mint.Alert {
ret := _m.ctrl.Call(_m, "Handshake")
ret0, _ := ret[0].(mint.Alert)
return ret0
}
// Handshake indicates an expected call of Handshake
func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake))
func (_mr *MockMintTLSMockRecorder) Handshake() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake))
}
// SetCryptoStream mocks base method
func (_m *MockMintTLS) SetCryptoStream(_param0 io.ReadWriter) {
_m.ctrl.Call(_m, "SetCryptoStream", _param0)
}
// SetCryptoStream indicates an expected call of SetCryptoStream
func (_mr *MockMintTLSMockRecorder) SetCryptoStream(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetCryptoStream", reflect.TypeOf((*MockMintTLS)(nil).SetCryptoStream), arg0)
}
// SetExtensionHandler mocks base method
func (_m *MockMintTLS) SetExtensionHandler(_param0 mint.AppExtensionHandler) error {
ret := _m.ctrl.Call(_m, "SetExtensionHandler", _param0)
ret0, _ := ret[0].(error)
return ret0
}
// SetExtensionHandler indicates an expected call of SetExtensionHandler
func (_mr *MockMintTLSMockRecorder) SetExtensionHandler(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetExtensionHandler", reflect.TypeOf((*MockMintTLS)(nil).SetExtensionHandler), arg0)
}
// State mocks base method
func (_m *MockmintTLS) State() mint.ConnectionState {
func (_m *MockMintTLS) State() mint.State {
ret := _m.ctrl.Call(_m, "State")
ret0, _ := ret[0].(mint.ConnectionState)
ret0, _ := ret[0].(mint.State)
return ret0
}
// State indicates an expected call of State
func (_mr *MockmintTLSMockRecorder) State() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State))
func (_mr *MockMintTLSMockRecorder) State() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockMintTLS)(nil).State))
}

View file

@ -0,0 +1,71 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler)
package mocks
import (
reflect "reflect"
mint "github.com/bifurcation/mint"
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
)
// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface
type MockTLSExtensionHandler struct {
ctrl *gomock.Controller
recorder *MockTLSExtensionHandlerMockRecorder
}
// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler
type MockTLSExtensionHandlerMockRecorder struct {
mock *MockTLSExtensionHandler
}
// NewMockTLSExtensionHandler creates a new mock instance
func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler {
mock := &MockTLSExtensionHandler{ctrl: ctrl}
mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (_m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder {
return _m.recorder
}
// GetPeerParams mocks base method
func (_m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters {
ret := _m.ctrl.Call(_m, "GetPeerParams")
ret0, _ := ret[0].(<-chan handshake.TransportParameters)
return ret0
}
// GetPeerParams indicates an expected call of GetPeerParams
func (_mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams))
}
// Receive mocks base method
func (_m *MockTLSExtensionHandler) Receive(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
ret := _m.ctrl.Call(_m, "Receive", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
// Receive indicates an expected call of Receive
func (_mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1)
}
// Send mocks base method
func (_m *MockTLSExtensionHandler) Send(_param0 mint.HandshakeType, _param1 *mint.ExtensionList) error {
ret := _m.ctrl.Call(_m, "Send", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
// Send indicates an expected call of Send
func (_mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1)
}

150
mint_utils.go Normal file
View file

@ -0,0 +1,150 @@
package quic
import (
"bytes"
gocrypto "crypto"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type mintController struct {
csc *handshake.CryptoStreamConn
conn *mint.Conn
}
var _ handshake.MintTLS = &mintController{}
func newMintController(
csc *handshake.CryptoStreamConn,
mconf *mint.Config,
pers protocol.Perspective,
) handshake.MintTLS {
var conn *mint.Conn
if pers == protocol.PerspectiveClient {
conn = mint.Client(csc, mconf)
} else {
conn = mint.Server(csc, mconf)
}
return &mintController{
csc: csc,
conn: conn,
}
}
func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
return mc.conn.State().CipherSuite
}
func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
return mc.conn.ComputeExporter(label, context, keyLength)
}
func (mc *mintController) Handshake() mint.Alert {
return mc.conn.Handshake()
}
func (mc *mintController) State() mint.State {
return mc.conn.State().HandshakeState
}
func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
mc.csc.SetStream(stream)
}
func (mc *mintController) SetExtensionHandler(h mint.AppExtensionHandler) error {
return mc.conn.SetExtensionHandler(h)
}
func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
mconf := &mint.Config{
NonBlocking: true,
CipherSuites: []mint.CipherSuite{
mint.TLS_AES_128_GCM_SHA256,
mint.TLS_AES_256_GCM_SHA384,
},
}
if tlsConf != nil {
mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
for i, certChain := range tlsConf.Certificates {
mconf.Certificates[i] = &mint.Certificate{
Chain: make([]*x509.Certificate, len(certChain.Certificate)),
PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
}
for j, cert := range certChain.Certificate {
c, err := x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
mconf.Certificates[i].Chain[j] = c
}
}
}
if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
return nil, err
}
return mconf, nil
}
// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets
// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0.
func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.StreamFrame, error) {
unpacker := &packetUnpacker{aead: &nullAEAD{aead}, version: version}
packet, err := unpacker.Unpack(hdr.Raw, hdr, data)
if err != nil {
return nil, err
}
var frame *wire.StreamFrame
for _, f := range packet.frames {
var ok bool
frame, ok = f.(*wire.StreamFrame)
if ok {
break
}
}
if frame == nil {
return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
}
// We don't need a check for the stream ID here.
// The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
if frame.Offset != 0 {
return nil, errors.New("received stream data with non-zero offset")
}
if utils.Debug() {
utils.Debugf("<- Reading packet 0x%x (%d bytes) for connection %x", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.ConnectionID)
hdr.Log()
wire.LogFrame(frame, false)
}
return frame, nil
}
// packUnencryptedPacket provides a low-overhead way to pack a packet.
// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) {
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)
if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
if err := sf.Write(buffer, hdr.Version); err != nil {
return nil, err
}
raw = raw[0:buffer.Len()]
_ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+aead.Overhead()]
if utils.Debug() {
utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted)
hdr.Log()
wire.LogFrame(sf, true)
}
return raw, nil
}

111
mint_utils_test.go Normal file
View file

@ -0,0 +1,111 @@
package quic
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Packing and unpacking Initial packets", func() {
var aead crypto.AEAD
connID := protocol.ConnectionID(0x1337)
ver := protocol.VersionTLS
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
PacketNumber: 0x42,
ConnectionID: connID,
Version: ver,
}
BeforeEach(func() {
var err error
aead, err = crypto.NewNullAEAD(protocol.PerspectiveServer, connID, ver)
Expect(err).ToNot(HaveOccurred())
// set hdr.Raw
buf := &bytes.Buffer{}
err = hdr.Write(buf, protocol.PerspectiveServer, ver)
Expect(err).ToNot(HaveOccurred())
hdr.Raw = buf.Bytes()
})
Context("unpacking", func() {
packPacket := func(frames []wire.Frame) []byte {
buf := &bytes.Buffer{}
err := hdr.Write(buf, protocol.PerspectiveClient, ver)
Expect(err).ToNot(HaveOccurred())
payloadStartIndex := buf.Len()
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
for _, f := range frames {
err := f.Write(buf, ver)
Expect(err).ToNot(HaveOccurred())
}
raw := buf.Bytes()
return aeadCl.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
}
It("unpacks a packet", func() {
f := &wire.StreamFrame{
StreamID: 0,
Data: []byte("foobar"),
}
p := packPacket([]wire.Frame{f})
frame, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
})
It("rejects a packet that doesn't contain a STREAM_FRAME", func() {
p := packPacket([]wire.Frame{&wire.PingFrame{}})
_, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).To(MatchError("Packet doesn't contain a STREAM_FRAME"))
})
It("rejects a packet that has a STREAM_FRAME for the wrong stream", func() {
f := &wire.StreamFrame{
StreamID: 42,
Data: []byte("foobar"),
}
p := packPacket([]wire.Frame{f})
_, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).To(MatchError("UnencryptedStreamData: received unencrypted stream data on stream 42"))
})
It("rejects a packet that has a STREAM_FRAME with a non-zero offset", func() {
f := &wire.StreamFrame{
StreamID: 0,
Offset: 10,
Data: []byte("foobar"),
}
p := packPacket([]wire.Frame{f})
_, err := unpackInitialPacket(aead, hdr, p, ver)
Expect(err).To(MatchError("received stream data with non-zero offset"))
})
})
Context("packing", func() {
var unpacker *packetUnpacker
BeforeEach(func() {
aeadCl, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, ver)
Expect(err).ToNot(HaveOccurred())
unpacker = &packetUnpacker{aead: &nullAEAD{aeadCl}, version: ver}
})
It("packs a packet", func() {
f := &wire.StreamFrame{
Data: []byte("foobar"),
FinBit: true,
}
data, err := packUnencryptedPacket(aead, hdr, f, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
packet, err := unpacker.Unpack(hdr.Raw, hdr, data[len(hdr.Raw):])
Expect(err).ToNot(HaveOccurred())
Expect(packet.frames).To(Equal([]wire.Frame{f}))
})
})
})

View file

@ -17,9 +17,9 @@ type packetNumberGenerator struct {
nextToSkip protocol.PacketNumber
}
func newPacketNumberGenerator(averagePeriod protocol.PacketNumber) *packetNumberGenerator {
func newPacketNumberGenerator(initial, averagePeriod protocol.PacketNumber) *packetNumberGenerator {
return &packetNumberGenerator{
next: 1,
next: initial,
averagePeriod: averagePeriod,
}
}

View file

@ -12,7 +12,12 @@ var _ = Describe("Packet Number Generator", func() {
var png packetNumberGenerator
BeforeEach(func() {
png = *newPacketNumberGenerator(100)
png = *newPacketNumberGenerator(1, 100)
})
It("can be initialized to return any first packet number", func() {
png = *newPacketNumberGenerator(12345, 100)
Expect(png.Pop()).To(Equal(protocol.PacketNumber(12345)))
})
It("gets 1 as the first packet number", func() {

View file

@ -32,9 +32,11 @@ type packetPacker struct {
ackFrame *wire.AckFrame
leastUnacked protocol.PacketNumber
omitConnectionID bool
hasSentPacket bool // has the packetPacker already sent a packet
}
func newPacketPacker(connectionID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber,
cryptoSetup handshake.CryptoSetup,
streamFramer *streamFramer,
perspective protocol.Perspective,
@ -46,7 +48,7 @@ func newPacketPacker(connectionID protocol.ConnectionID,
perspective: perspective,
version: version,
streamFramer: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
}
}
@ -116,7 +118,12 @@ func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*
// PackPacket packs a new packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket() (*packedPacket, error) {
if p.streamFramer.HasCryptoStreamFrame() {
hasCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame()
// if this is the first packet to be send, make sure it contains stream data
if !p.hasSentPacket && !hasCryptoStreamFrame {
return nil, nil
}
if hasCryptoStreamFrame {
return p.packCryptoPacket()
}
@ -266,18 +273,21 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
pnum := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForHeader(pnum, p.leastUnacked)
var isLongHeader bool
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
// TODO: set the Long Header type
packetNumberLen = protocol.PacketNumberLen4
isLongHeader = true
}
header := &wire.Header{
ConnectionID: p.connectionID,
PacketNumber: pnum,
PacketNumberLen: packetNumberLen,
IsLongHeader: isLongHeader,
}
if p.version.UsesTLS() && encLevel != protocol.EncryptionForwardSecure {
header.PacketNumberLen = protocol.PacketNumberLen4
header.IsLongHeader = true
if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
header.Type = protocol.PacketTypeInitial
// TODO(#886): add padding
} else {
header.Type = protocol.PacketTypeHandshake
}
}
if p.omitConnectionID && encLevel == protocol.EncryptionForwardSecure {
@ -292,7 +302,6 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
header.Version = p.version
}
} else {
header.Type = p.cryptoSetup.GetNextPacketType()
if encLevel != protocol.EncryptionForwardSecure {
header.Version = p.version
}
@ -330,7 +339,7 @@ func (p *packetPacker) writeAndSealPacket(
if num != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
p.hasSentPacket = true
return raw, nil
}

View file

@ -28,7 +28,6 @@ type mockCryptoSetup struct {
divNonce []byte
encLevelSeal protocol.EncryptionLevel
encLevelSealCrypto protocol.EncryptionLevel
nextPacketType protocol.PacketType
}
var _ handshake.CryptoSetup = &mockCryptoSetup{}
@ -50,7 +49,6 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel)
}
func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce }
func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce }
func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType }
var _ = Describe("Packet packer", func() {
var (
@ -69,13 +67,14 @@ var _ = Describe("Packet packer", func() {
packer = &packetPacker{
cryptoSetup: &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure},
connectionID: 0x1337,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
packetNumberGenerator: newPacketNumberGenerator(1, protocol.SkipPacketAveragePeriodLength),
streamFramer: streamFramer,
perspective: protocol.PerspectiveServer,
}
publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number
maxFrameSize = protocol.MaxPacketSize - protocol.ByteCount((&mockSealer{}).Overhead()) - publicHeaderLen
packer.version = protocol.VersionWhatever
packer.hasSentPacket = true
})
It("returns nil when no packet is queued", func() {
@ -191,13 +190,6 @@ var _ = Describe("Packet packer", func() {
Expect(h.Version).To(Equal(versionIETFHeader))
})
It("sets the packet type based on the state of the handshake", func() {
packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5
h := packer.getHeader(protocol.EncryptionSecure)
Expect(h.IsLongHeader).To(BeTrue())
Expect(h.Type).To(Equal(protocol.PacketType(5)))
})
It("uses the Short Header format for forward-secure packets", func() {
h := packer.getHeader(protocol.EncryptionForwardSecure)
Expect(h.IsLongHeader).To(BeFalse())
@ -269,7 +261,7 @@ var _ = Describe("Packet packer", func() {
Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber))
})
It("packs a StopWaitingFrame first", func() {
It("packs a STOP_WAITING frame first", func() {
packer.packetNumberGenerator.next = 15
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(&wire.RstStreamFrame{})
@ -281,7 +273,7 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames[0]).To(Equal(swf))
})
It("sets the LeastUnackedDelta length of a StopWaitingFrame", func() {
It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() {
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
packer.packetNumberGenerator.next = packetNumber
swf := &wire.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
@ -292,7 +284,7 @@ var _ = Describe("Packet packer", func() {
Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
})
It("does not pack a packet containing only a StopWaitingFrame", func() {
It("does not pack a packet containing only a STOP_WAITING frame", func() {
swf := &wire.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrame(swf)
p, err := packer.PackPacket()
@ -307,6 +299,14 @@ var _ = Describe("Packet packer", func() {
Expect(p).ToNot(BeNil())
})
It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() {
packer.hasSentPacket = false
packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}}
p, err := packer.PackPacket()
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
It("packs many control frames into 1 packets", func() {
f := &wire.AckFrame{LargestAcked: 1}
b := &bytes.Buffer{}
@ -602,7 +602,7 @@ var _ = Describe("Packet packer", func() {
})
})
Context("Blocked frames", func() {
Context("BLOCKED frames", func() {
It("queues a BLOCKED frame", func() {
length := 100
streamFramer.blockedFrameQueue = []wire.Frame{&wire.StreamBlockedFrame{StreamID: 5}}
@ -750,7 +750,7 @@ var _ = Describe("Packet packer", func() {
Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment"))
})
It("refuses to retransmit packets without a StopWaitingFrame", func() {
It("refuses to retransmit packets without a STOP_WAITING Frame", func() {
packer.stopWaiting = nil
_, err := packer.PackHandshakeRetransmission(&ackhandler.Packet{
EncryptionLevel: protocol.EncryptionSecure,

View file

@ -19,6 +19,7 @@ import (
// packetHandler handles packets
type packetHandler interface {
Session
getCryptoStream() cryptoStream
handshakeStatus() <-chan handshakeEvent
handlePacket(*receivedPacket)
GetVersion() protocol.VersionNumber
@ -33,6 +34,9 @@ type server struct {
conn net.PacketConn
supportsTLS bool
serverTLS *serverTLS
certChain crypto.CertChain
scfg *handshake.ServerConfig
@ -77,11 +81,21 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
if err != nil {
return nil, err
}
config = populateServerConfig(config)
// check if any of the supported versions supports TLS
var supportsTLS bool
for _, v := range config.Versions {
if v.UsesTLS() {
supportsTLS = true
break
}
}
s := &server{
conn: conn,
tlsConf: tlsConf,
config: populateServerConfig(config),
config: config,
certChain: certChain,
scfg: scfg,
sessions: map[protocol.ConnectionID]packetHandler{},
@ -89,12 +103,47 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
sessionQueue: make(chan Session, 5),
errorChan: make(chan struct{}),
supportsTLS: supportsTLS,
}
if supportsTLS {
if err := s.setupTLS(); err != nil {
return nil, err
}
}
go s.serve()
utils.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil
}
func (s *server) setupTLS() error {
cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie)
if err != nil {
return err
}
serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, cookieHandler, s.tlsConf)
if err != nil {
return err
}
s.serverTLS = serverTLS
// handle TLS connection establishment statelessly
go func() {
for {
select {
case <-s.errorChan:
return
case sess := <-sessionChan:
// TODO: think about what to do with connection ID collisions
connID := sess.(*session).connectionID
s.sessionsMutex.Lock()
s.sessions[connID] = sess
s.sessionsMutex.Unlock()
s.runHandshakeAndSession(sess, connID)
}
}
}()
return nil
}
var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
if cookie == nil {
return false
@ -225,8 +274,16 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}
hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
connID := hdr.ConnectionID
if hdr.Type == protocol.PacketTypeInitial {
if s.supportsTLS {
go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
}
return nil
}
s.sessionsMutex.RLock()
session, sessionKnown := s.sessions[connID]
s.sessionsMutex.RUnlock()
@ -279,11 +336,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return err
}
}
// send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header
if hdr.Type == protocol.PacketTypeInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
_, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.config.Versions), remoteAddr)
return err
}
if !sessionKnown {
version := hdr.Version
@ -307,9 +359,21 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
s.sessions[connID] = session
s.sessionsMutex.Unlock()
s.runHandshakeAndSession(session, connID)
}
session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
return nil
}
func (s *server) runHandshakeAndSession(session packetHandler, connID protocol.ConnectionID) {
go func() {
// session.run() returns as soon as the session is closed
_ = session.run()
// session.run() returns as soon as the session is closed
s.removeConnection(connID)
}()
@ -326,14 +390,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
s.sessionQueue <- session
}()
}
session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
})
return nil
}
func (s *server) removeConnection(id protocol.ConnectionID) {
s.sessionsMutex.Lock()

View file

@ -12,6 +12,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
@ -67,6 +68,7 @@ func (s *mockSession) RemoteAddr() net.Addr { panic("not imple
func (*mockSession) Context() context.Context { panic("not implemented") }
func (*mockSession) GetVersion() protocol.VersionNumber { return protocol.VersionWhatever }
func (s *mockSession) handshakeStatus() <-chan handshakeEvent { return s.handshakeChan }
func (*mockSession) getCryptoStream() cryptoStream { panic("not implemented") }
var _ Session = &mockSession{}
var _ NonFWSession = &mockSession{}
@ -96,7 +98,8 @@ var _ = Describe("Server", func() {
)
BeforeEach(func() {
conn = &mockPacketConn{addr: &net.UDPAddr{}}
conn = newMockPacketConn()
conn.addr = &net.UDPAddr{}
config = &Config{Versions: protocol.SupportedVersions}
})
@ -235,14 +238,14 @@ var _ = Describe("Server", func() {
})
It("works if no quic.Config is given", func(done Done) {
ln, err := ListenAddr("127.0.0.1:0", nil, config)
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred())
Expect(ln.Close()).To(Succeed())
close(done)
}, 1)
It("closes properly", func() {
ln, err := ListenAddr("127.0.0.1:0", nil, config)
ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
var returned bool
@ -409,7 +412,7 @@ var _ = Describe("Server", func() {
}
hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */)
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead = b.Bytes()
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, nil, config)
Expect(err).ToNot(HaveOccurred())
@ -432,7 +435,7 @@ var _ = Describe("Server", func() {
})
It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
config.Versions = []protocol.VersionNumber{99}
config.Versions = []protocol.VersionNumber{99, protocol.VersionTLS}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeInitial,
@ -441,11 +444,12 @@ var _ = Describe("Server", func() {
PacketNumber: 0x55,
Version: 0x1234,
}
hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead = b.Bytes()
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, nil, config)
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
@ -466,9 +470,32 @@ var _ = Describe("Server", func() {
Consistently(done).ShouldNot(BeClosed())
})
It("ignores IETF draft style Initial packets, if it doesn't support TLS", func() {
version := protocol.VersionNumber(99)
Expect(version.UsesTLS()).To(BeFalse())
config.Versions = []protocol.VersionNumber{version}
b := &bytes.Buffer{}
hdr := wire.Header{
Type: protocol.PacketTypeInitial,
IsLongHeader: true,
ConnectionID: 0x1337,
PacketNumber: 0x55,
Version: protocol.VersionTLS,
}
err := hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO
conn.dataToRead <- b.Bytes()
conn.dataReadFrom = udpAddr
ln, err := Listen(conn, testdata.GetTLSConfig(), config)
defer ln.Close()
Expect(err).ToNot(HaveOccurred())
Consistently(func() int { return conn.dataWritten.Len() }).Should(BeZero())
})
It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {
conn.dataReadFrom = udpAddr
conn.dataToRead = []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
conn.dataToRead <- []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}
ln, err := Listen(conn, nil, config)
Expect(err).ToNot(HaveOccurred())
go func() {

179
server_tls.go Normal file
View file

@ -0,0 +1,179 @@
package quic
import (
"crypto/tls"
"fmt"
"net"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type nullAEAD struct {
aead crypto.AEAD
}
var _ quicAEAD = &nullAEAD{}
func (n *nullAEAD) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
data, err := n.aead.Open(dst, src, packetNumber, associatedData)
return data, protocol.EncryptionUnencrypted, err
}
type serverTLS struct {
conn net.PacketConn
config *Config
supportedVersions []protocol.VersionNumber
mintConf *mint.Config
cookieProtector mint.CookieProtector
params *handshake.TransportParameters
newMintConn func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
sessionChan chan<- packetHandler
}
func newServerTLS(
conn net.PacketConn,
config *Config,
cookieHandler *handshake.CookieHandler,
tlsConf *tls.Config,
) (*serverTLS, <-chan packetHandler, error) {
mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer)
if err != nil {
return nil, nil, err
}
mconf.RequireCookie = true
cs, err := mint.NewDefaultCookieProtector()
if err != nil {
return nil, nil, err
}
mconf.CookieProtector = cs
mconf.CookieHandler = cookieHandler
sessionChan := make(chan packetHandler)
s := &serverTLS{
conn: conn,
config: config,
supportedVersions: config.Versions,
mintConf: mconf,
sessionChan: sessionChan,
params: &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: config.IdleTimeout,
},
}
s.newMintConn = s.newMintConnImpl
return s, sessionChan, nil
}
func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
utils.Debugf("Received a Packet. Handling it statelessly.")
sess, err := s.handleInitialImpl(remoteAddr, hdr, data)
if err != nil {
utils.Errorf("Error occured handling initial packet: %s", err)
return
}
if sess == nil { // a stateless reset was done
return
}
s.sessionChan <- sess
}
// will be set to s.newMintConn by the constructor
func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
conn := mint.Server(bc, s.mintConf)
extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v)
if err := conn.SetExtensionHandler(extHandler); err != nil {
return nil, nil, err
}
tls := newMintController(bc, s.mintConf, protocol.PerspectiveServer)
tls.SetExtensionHandler(extHandler)
return tls, extHandler.GetPeerParams(), nil
}
func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) {
// TODO: check length requirement
// check version, if not matching send VNP
if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
utils.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
_, err := s.conn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.supportedVersions), remoteAddr)
return nil, err
}
// unpack packet and check stream frame contents
version := hdr.Version
aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version)
if err != nil {
return nil, err
}
frame, err := unpackInitialPacket(aead, hdr, data, version)
if err != nil {
utils.Debugf("Error unpacking initial packet: %s", err)
return nil, nil
}
bc := handshake.NewCryptoStreamConn(remoteAddr)
bc.AddDataForReading(frame.Data)
tls, paramsChan, err := s.newMintConn(bc, hdr.Version)
if err != nil {
return nil, err
}
alert := tls.Handshake()
if alert == mint.AlertStatelessRetry {
// the HelloRetryRequest was written to the bufferConn
// Take that data and write send a Retry packet
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
ConnectionID: hdr.ConnectionID, // echo the client's connection ID
PacketNumber: hdr.PacketNumber, // echo the client's packet number
Version: version,
}
f := &wire.StreamFrame{
StreamID: version.CryptoStreamID(),
Data: bc.GetDataForWriting(),
}
data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer)
if err != nil {
return nil, err
}
_, err = s.conn.WriteTo(data, remoteAddr)
return nil, err
}
if alert != mint.AlertNoAlert {
return nil, alert
}
if tls.State() != mint.StateServerNegotiated {
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
}
if alert := tls.Handshake(); alert != mint.AlertNoAlert {
return nil, alert
}
if tls.State() != mint.StateServerWaitFlight2 {
return nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
}
params := <-paramsChan
sess, err := newTLSServerSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
hdr.ConnectionID, // TODO: we can use a server-chosen connection ID here
protocol.PacketNumber(1), // TODO: use a random packet number here
s.config,
tls,
bc,
aead,
&params,
version,
)
if err != nil {
return nil, err
}
cs := sess.getCryptoStream()
cs.SetReadOffset(frame.DataLen())
bc.SetStream(cs)
return sess, nil
}

116
server_tls_test.go Normal file
View file

@ -0,0 +1,116 @@
package quic
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/mocks/handshake"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Stateless TLS handling", func() {
var (
conn *mockPacketConn
server *serverTLS
sessionChan <-chan packetHandler
mintTLS *mockhandshake.MockMintTLS
extHandler *mocks.MockTLSExtensionHandler
mintReply io.Writer
)
BeforeEach(func() {
mintTLS = mockhandshake.NewMockMintTLS(mockCtrl)
extHandler = mocks.NewMockTLSExtensionHandler(mockCtrl)
conn = newMockPacketConn()
config := &Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
var err error
server, sessionChan, err = newServerTLS(conn, config, nil, testdata.GetTLSConfig())
Expect(err).ToNot(HaveOccurred())
server.newMintConn = func(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
mintReply = bc
return mintTLS, extHandler.GetPeerParams(), nil
}
})
getPacket := func(f wire.Frame) (*wire.Header, []byte) {
hdrBuf := &bytes.Buffer{}
hdr := &wire.Header{
IsLongHeader: true,
PacketNumber: 1,
Version: protocol.VersionTLS,
}
err := hdr.Write(hdrBuf, protocol.PerspectiveClient, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
hdr.Raw = hdrBuf.Bytes()
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, 0, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
buf := &bytes.Buffer{}
err = f.Write(buf, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
return hdr, aead.Seal(nil, buf.Bytes(), 1, hdr.Raw)
}
It("sends a version negotiation packet if it doesn't support the version", func() {
server.HandleInitial(nil, &wire.Header{Version: 0x1337}, nil)
Expect(conn.dataWritten.Len()).ToNot(BeZero())
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionUnknown)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsVersionNegotiation).To(BeTrue())
Expect(sessionChan).ToNot(Receive())
})
It("ignores packets with invalid contents", func() {
hdr, data := getPacket(&wire.StreamFrame{StreamID: 10, Offset: 11, Data: []byte("foobar")})
server.HandleInitial(nil, hdr, data)
Expect(conn.dataWritten.Len()).To(BeZero())
Expect(sessionChan).ToNot(Receive())
})
It("replies with a Retry packet, if a Cookie is required", func() {
extHandler.EXPECT().GetPeerParams()
mintTLS.EXPECT().Handshake().Return(mint.AlertStatelessRetry).Do(func() {
mintReply.Write([]byte("Retry with this Cookie"))
})
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
server.HandleInitial(nil, hdr, data)
Expect(conn.dataWritten.Len()).ToNot(BeZero())
hdr, err := wire.ParseHeaderSentByServer(bytes.NewReader(conn.dataWritten.Bytes()), protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(sessionChan).ToNot(Receive())
})
It("replies with a Handshake packet and creates a session, if no Cookie is required", func() {
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert).Do(func() {
mintReply.Write([]byte("Server Hello"))
})
mintTLS.EXPECT().Handshake().Return(mint.AlertNoAlert)
mintTLS.EXPECT().State().Return(mint.StateServerNegotiated)
mintTLS.EXPECT().State().Return(mint.StateServerWaitFlight2)
paramsChan := make(chan handshake.TransportParameters, 1)
paramsChan <- handshake.TransportParameters{}
extHandler.EXPECT().GetPeerParams().Return(paramsChan)
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
done := make(chan struct{})
go func() {
defer GinkgoRecover()
server.HandleInitial(nil, hdr, data)
// the Handshake packet is written by the session
Expect(conn.dataWritten.Len()).To(BeZero())
close(done)
}()
Eventually(sessionChan).Should(Receive())
Eventually(done).Should(BeClosed())
})
})

View file

@ -11,6 +11,7 @@ import (
"github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
@ -60,7 +61,7 @@ type session struct {
conn connection
streamsMap *streamsMap
cryptoStream streamI
cryptoStream cryptoStream
rttStats *congestion.RTTStats
@ -126,21 +127,48 @@ func newSession(
conn connection,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
sCfg *handshake.ServerConfig,
scfg *handshake.ServerConfig,
tlsConf *tls.Config,
config *Config,
) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s := &session{
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
config: config,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
}
return s, s.setup(sCfg, "", tlsConf, v, nil)
s.preSetup()
transportParams := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: s.config.IdleTimeout,
}
cs, err := newCryptoSetup(
s.cryptoStream,
s.connectionID,
s.conn.RemoteAddr(),
s.version,
scfg,
transportParams,
s.config.Versions,
s.config.AcceptCookie,
paramsChan,
aeadChanged,
)
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(1)
}
// declare this as a variable, such that we can it mock it in the tests
// declare this as a variable, so that we can it mock it in the tests
var newClientSession = func(
conn connection,
hostname string,
@ -151,27 +179,130 @@ var newClientSession = func(
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber, // needed for validation of the GQUIC version negotiaton
) (packetHandler, error) {
paramsChan := make(chan handshake.TransportParameters)
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s := &session{
conn: conn,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
config: config,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
}
return s, s.setup(nil, hostname, tlsConf, initialVersion, negotiatedVersions)
s.preSetup()
transportParams := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: s.config.IdleTimeout,
OmitConnectionID: s.config.RequestConnectionIDOmission,
}
cs, err := newCryptoSetupClient(
s.cryptoStream,
hostname,
s.connectionID,
s.version,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
initialVersion,
negotiatedVersions,
)
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(1)
}
func (s *session) setup(
scfg *handshake.ServerConfig,
hostname string,
tlsConf *tls.Config,
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
) error {
func newTLSServerSession(
conn connection,
connectionID protocol.ConnectionID,
initialPacketNumber protocol.PacketNumber,
config *Config,
tls handshake.MintTLS,
cryptoStreamConn *handshake.CryptoStreamConn,
nullAEAD crypto.AEAD,
peerParams *handshake.TransportParameters,
v protocol.VersionNumber,
) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2)
paramsChan := make(chan handshake.TransportParameters)
s.aeadChanged = aeadChanged
s.paramsChan = paramsChan
s := &session{
conn: conn,
config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveServer,
version: v,
aeadChanged: aeadChanged,
}
s.preSetup()
s.cryptoSetup = handshake.NewCryptoSetupTLSServer(
tls,
cryptoStreamConn,
nullAEAD,
aeadChanged,
v,
)
if err := s.postSetup(initialPacketNumber); err != nil {
return nil, err
}
s.peerParams = peerParams
s.processTransportParameters(peerParams)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, nil
}
// declare this as a variable, such that we can it mock it in the tests
var newTLSClientSession = func(
conn connection,
hostname string,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
config *Config,
tls handshake.MintTLS,
paramsChan <-chan handshake.TransportParameters,
initialPacketNumber protocol.PacketNumber,
) (packetHandler, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s := &session{
conn: conn,
config: config,
connectionID: connectionID,
perspective: protocol.PerspectiveClient,
version: v,
aeadChanged: aeadChanged,
paramsChan: paramsChan,
}
s.preSetup()
tls.SetCryptoStream(s.cryptoStream)
cs, err := handshake.NewCryptoSetupTLSClient(
s.cryptoStream,
s.connectionID,
hostname,
aeadChanged,
tls,
v,
)
if err != nil {
return nil, err
}
s.cryptoSetup = cs
return s, s.postSetup(initialPacketNumber)
}
func (s *session) preSetup() {
s.rttStats = &congestion.RTTStats{}
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.cryptoStream = s.newStream(s.version.CryptoStreamID()).(cryptoStream)
}
func (s *session) postSetup(initialPacketNumber protocol.PacketNumber) error {
s.handshakeChan = make(chan handshakeEvent, 3)
s.handshakeCompleteChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
@ -185,91 +316,14 @@ func (s *session) setup(
s.lastNetworkActivityTime = now
s.sessionCreationTime = now
s.rttStats = &congestion.RTTStats{}
transportParams := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
MaxStreams: protocol.MaxIncomingStreams,
IdleTimeout: s.config.IdleTimeout,
}
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.version)
s.connFlowController = flowcontrol.NewConnectionFlowController(
protocol.ReceiveConnectionFlowControlWindow,
protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow),
s.rttStats,
)
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.version)
s.cryptoStream = s.newStream(s.version.CryptoStreamID())
s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.connFlowController)
var err error
if s.perspective == protocol.PerspectiveServer {
verifySourceAddr := func(clientAddr net.Addr, cookie *Cookie) bool {
return s.config.AcceptCookie(clientAddr, cookie)
}
if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSServer(
s.cryptoStream,
s.connectionID,
tlsConf,
s.conn.RemoteAddr(),
transportParams,
paramsChan,
aeadChanged,
verifySourceAddr,
s.config.Versions,
s.version,
)
} else {
s.cryptoSetup, err = newCryptoSetup(
s.cryptoStream,
s.connectionID,
s.conn.RemoteAddr(),
s.version,
scfg,
transportParams,
s.config.Versions,
verifySourceAddr,
paramsChan,
aeadChanged,
)
}
} else {
transportParams.OmitConnectionID = s.config.RequestConnectionIDOmission
if s.version.UsesTLS() {
s.cryptoSetup, err = handshake.NewCryptoSetupTLSClient(
s.cryptoStream,
s.connectionID,
hostname,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
initialVersion,
s.config.Versions,
s.version,
)
} else {
s.cryptoSetup, err = newCryptoSetupClient(
s.cryptoStream,
hostname,
s.connectionID,
s.version,
tlsConf,
transportParams,
paramsChan,
aeadChanged,
initialVersion,
negotiatedVersions,
)
}
}
if err != nil {
return err
}
s.packer = newPacketPacker(s.connectionID,
initialPacketNumber,
s.cryptoSetup,
s.streamFramer,
s.perspective,
@ -604,7 +658,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
s.cryptoStream.Cancel(quicErr)
s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion {
if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
return nil
}
@ -893,6 +947,10 @@ func (s *session) handshakeStatus() <-chan handshakeEvent {
return s.handshakeChan
}
func (s *session) getCryptoStream() cryptoStream {
return s.cryptoStream
}
func (s *session) GetVersion() protocol.VersionNumber {
return s.version
}

View file

@ -468,9 +468,6 @@ var _ = Describe("Session", func() {
})
It("handles CONNECTION_CLOSE frames", func() {
cryptoStream := mocks.NewMockStreamI(mockCtrl)
cryptoStream.EXPECT().Cancel(gomock.Any())
sess.cryptoStream = cryptoStream
done := make(chan struct{})
go func() {
defer GinkgoRecover()
@ -771,10 +768,15 @@ var _ = Describe("Session", func() {
})
Context("sending packets", func() {
BeforeEach(func() {
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
})
It("sends ACK frames", func() {
packetNumber := protocol.PacketNumber(0x035e)
sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
err := sess.sendPacket()
err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
Expect(err).ToNot(HaveOccurred())
err = sess.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e}))))
@ -858,6 +860,7 @@ var _ = Describe("Session", func() {
BeforeEach(func() {
// a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet
sess.packer.packetNumberGenerator.next = 0x1337 + 10
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
sph = newMockSentPacketHandler().(*mockSentPacketHandler)
sess.sentPacketHandler = sph
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
@ -981,6 +984,7 @@ var _ = Describe("Session", func() {
})
It("retransmits RTO packets", func() {
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
sess.sentPacketHandler.SetHandshakeComplete()
n := protocol.PacketNumber(10)
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
@ -1008,6 +1012,7 @@ var _ = Describe("Session", func() {
Context("scheduling sending", func() {
BeforeEach(func() {
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
sess.processTransportParameters(&handshake.TransportParameters{
StreamFlowControlWindow: protocol.MaxByteCount,
ConnectionFlowControlWindow: protocol.MaxByteCount,
@ -1291,6 +1296,7 @@ var _ = Describe("Session", func() {
sess.handshakeComplete = true
sess.config.KeepAlive = true
sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2)
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
go sess.run()
defer sess.Close(nil)
var data []byte
@ -1551,7 +1557,10 @@ var _ = Describe("Client Session", func() {
})
It("passes the diversification nonce to the cryptoSetup", func() {
go sess.run()
go func() {
defer GinkgoRecover()
sess.run()
}()
hdr.PacketNumber = 5
hdr.DiversificationNonce = []byte("foobar")
err := sess.handlePacketImpl(&receivedPacket{header: hdr})

View file

@ -32,6 +32,11 @@ type streamI interface {
IsFlowControlBlocked() bool
}
type cryptoStream interface {
streamI
SetReadOffset(protocol.ByteCount)
}
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
//
// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually.
@ -481,3 +486,11 @@ func (s *stream) IsFlowControlBlocked() bool {
func (s *stream) GetWindowUpdate() protocol.ByteCount {
return s.flowController.GetWindowUpdate()
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *stream) SetReadOffset(offset protocol.ByteCount) {
s.readOffset = offset
s.frameQueue.readPosition = offset
}

View file

@ -266,6 +266,12 @@ var _ = Describe("Stream", func() {
Expect(onDataCalled).To(BeTrue())
})
It("sets the read offset", func() {
str.SetReadOffset(0x42)
Expect(str.readOffset).To(Equal(protocol.ByteCount(0x42)))
Expect(str.frameQueue.readPosition).To(Equal(protocol.ByteCount(0x42)))
})
Context("deadlines", func() {
It("the deadline error has the right net.Error properties", func() {
Expect(errDeadline.Temporary()).To(BeTrue())

View file

@ -325,4 +325,5 @@ func (m *streamsMap) UpdateMaxStreamLimit(limit uint32) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.maxOutgoingStreams = limit
m.openStreamOrErrCond.Broadcast()
}