drop support for gQUIC

This commit is contained in:
Marten Seemann 2018-10-24 09:34:44 +07:00
parent 8f8ed03254
commit 3266e36811
195 changed files with 2638 additions and 35430 deletions

View file

@ -14,7 +14,6 @@ defaults: &defaults
command: |
echo $GOARCH
go version
google-chrome --version
printf "quic.clemente.io certificate valid until: " && openssl x509 -in example/fullchain.pem -enddate -noout | cut -d = -f 2
- run:
name: "Run benchmark tests"
@ -25,12 +24,6 @@ defaults: &defaults
- run:
name: "Run tools tests"
command: ginkgo -r -v -randomizeAllSpecs -trace integrationtests/tools
- run:
name: "Run Chrome integration tests"
command: ginkgo -v -randomizeAllSpecs -trace integrationtests/chrome
- run:
name: "Run integration tests using the gQUIC toy server / client"
command: ginkgo -v -randomizeAllSpecs -trace integrationtests/gquic
- run:
name: "Run self integration tests"
command: ginkgo -v -randomizeAllSpecs -trace integrationtests/self

View file

@ -22,5 +22,5 @@ if [ ${TESTMODE} == "integration" ]; then
ginkgo -race -randomizeAllSpecs -randomizeSuites -trace benchmark -- -samples=1 -size=10
fi
# run integration tests
ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage chrome integrationtests
ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage integrationtests
fi

View file

@ -8,12 +8,19 @@
[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/)
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go. It roughly implements the [IETF QUIC draft](https://github.com/quicwg/base-drafts), although we don't fully support any of the draft versions at the moment.
## Roadmap
## Version compatibility
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
Since quic-go is under active development, there's no guarantee that two builds of different commits are interoperable. The QUIC version used in the *master* branch is just a placeholder, and should not be considered stable.
If you want to use quic-go as a library in other projects, please consider using a [tagged release](https://github.com/lucas-clemente/quic-go/releases). These releases expose [experimental QUIC versions](https://github.com/quicwg/base-drafts/wiki/QUIC-Versions), which are guaranteed to be stable.
## Google QUIC
quic-go used to support both the QUIC versions supported by Google Chrome and QUIC as deployed on Google's servers, as well as IETF QUIC. Due to the divergence of the two protocols, we decided to not support both versions any more.
The *master* branch **only** supports IETF QUIC. For Google QUIC support, please refer to the [gquic branch](https://github.com/lucas-clemente/quic-go/tree/gquic).
## Guides
@ -27,31 +34,19 @@ Running tests:
go test ./...
### Running the example server
### HTTP mapping
go run example/main.go -www /var/www/
Using the `quic_client` from chromium:
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
### QUIC without HTTP/2
Take a look at [this echo example](example/echo/echo.go).
### Using the example client
go run example/client/main.go https://clemente.io
## Usage
### As a server
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))

127
client.go
View file

@ -1,7 +1,6 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
@ -123,13 +122,6 @@ func dialContext(
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
if !createdPacketConn {
for _, v := range config.Versions {
if v == protocol.Version44 {
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
}
}
}
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
@ -234,17 +226,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
for _, v := range versions {
if v == protocol.Version44 {
connIDLen = 0
}
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
ConnectionIDLength: connIDLen,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
@ -255,53 +241,22 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
}
func (c *client) generateConnectionIDs() error {
connIDLen := protocol.ConnectionIDLenGQUIC
if c.version.UsesTLS() {
connIDLen = c.config.ConnectionIDLength
}
srcConnID, err := generateConnectionID(connIDLen)
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = generateConnectionIDForInitial()
if err != nil {
return err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return err
}
c.srcConnID = srcConnID
c.destConnID = destConnID
if c.version == protocol.Version44 {
c.srcConnID = nil
}
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
var err error
if c.version.UsesTLS() {
err = c.dialTLS(ctx)
} else {
err = c.dialGQUIC(ctx)
}
return err
}
func (c *client) dialGQUIC(ctx context.Context) error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
func (c *client) dialTLS(ctx context.Context) error {
if err := c.createNewTLSSession(c.version); err != nil {
return err
}
@ -315,9 +270,9 @@ func (c *client) dialTLS(ctx context.Context) error {
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
// - when the connection is forward-secure
func (c *client) establishSecureConnection(ctx context.Context) error {
errorChan := make(chan error, 1)
@ -362,24 +317,9 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
return err
}
if !c.version.UsesIETFHeaderFormat() {
connID := p.header.DestConnectionID
// reject packets with truncated connection id if we didn't request truncation
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
// reject packets with the wrong connection ID
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
}
if p.header.ResetFlag {
return c.handlePublicReset(p)
}
} else {
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
if p.header.Type == protocol.PacketTypeRetry {
@ -397,22 +337,6 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
return nil
}
func (c *client) handlePublicReset(p *receivedPacket) error {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
if err != nil {
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
}
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
@ -471,42 +395,11 @@ func (c *client) handleRetryPacket(hdr *wire.Header) {
c.session.destroy(errCloseSessionForRetry)
}
func (c *client) createNewGQUICSession() error {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newClientSession(
c.conn,
runner,
c.version,
c.destConnID,
c.srcConnID,
c.tlsConf,
c.config,
c.initialVersion,
c.negotiatedVersions,
c.logger,
)
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
if c.config.RequestConnectionIDOmission {
c.packetHandlers.Add(protocol.ConnectionID{}, c)
}
return nil
}
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
DisableMigration: true,
@ -518,7 +411,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newTLSClientSession(
sess, err := newClientSession(
c.conn,
runner,
c.token,

View file

@ -30,9 +30,20 @@ var _ = Describe("Client", func() {
mockMultiplexer *MockMultiplexer
origMultiplexer multiplexer
supportedVersionsWithoutGQUIC44 []protocol.VersionNumber
originalClientSessConstructor func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (quicSession, error)
originalClientSessConstructor func(
conn connection,
runner sessionRunner,
token []byte,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
conf *Config,
tlsConf *tls.Config,
params *handshake.TransportParameters,
initialVersion protocol.VersionNumber,
initialPacketNumber protocol.PacketNumber,
logger utils.Logger,
v protocol.VersionNumber,
) (quicSession, error)
)
// generate a packet sent by the server that accepts the QUIC version suggested by the client
@ -79,12 +90,6 @@ var _ = Describe("Client", func() {
mockMultiplexer = NewMockMultiplexer(mockCtrl)
origMultiplexer = connMuxer
connMuxer = mockMultiplexer
for _, v := range protocol.SupportedVersions {
if v != protocol.Version44 {
supportedVersionsWithoutGQUIC44 = append(supportedVersionsWithoutGQUIC44, v)
}
}
Expect(supportedVersionsWithoutGQUIC44).ToNot(BeEmpty())
})
AfterEach(func() {
@ -131,14 +136,16 @@ var _ = Describe("Client", func() {
newClientSession = func(
conn connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
remoteAddrChan <- conn.RemoteAddr().String()
sess := NewMockQuicSession(mockCtrl)
@ -159,14 +166,16 @@ var _ = Describe("Client", func() {
newClientSession = func(
_ connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
tlsConf *tls.Config,
_ *Config,
tlsConf *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
hostnameChan <- tlsConf.ServerName
sess := NewMockQuicSession(mockCtrl)
@ -187,14 +196,16 @@ var _ = Describe("Client", func() {
newClientSession = func(
_ connection,
runner sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Do(func() { close(run) })
@ -206,25 +217,13 @@ var _ = Describe("Client", func() {
addr,
"quic.clemente.io:1337",
nil,
&Config{Versions: supportedVersionsWithoutGQUIC44},
&Config{},
)
Expect(err).ToNot(HaveOccurred())
Expect(s).ToNot(BeNil())
Eventually(run).Should(BeClosed())
})
It("refuses to multiplex gQUIC 44", func() {
_, err := Dial(
packetConn,
addr,
"quic.clemente.io:1337",
nil,
&Config{Versions: []protocol.VersionNumber{protocol.Version44}},
)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("Cannot multiplex connections using gQUIC 44"))
})
It("returns an error that occurs while waiting for the connection to become secure", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
@ -232,16 +231,18 @@ var _ = Describe("Client", func() {
testErr := errors.New("early handshake error")
newClientSession = func(
conn connection,
_ connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr)
@ -253,7 +254,7 @@ var _ = Describe("Client", func() {
addr,
"quic.clemente.io:1337",
nil,
&Config{Versions: supportedVersionsWithoutGQUIC44},
&Config{},
)
Expect(err).To(MatchError(testErr))
})
@ -270,16 +271,18 @@ var _ = Describe("Client", func() {
<-sessionRunning
})
newClientSession = func(
conn connection,
_ connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
return sess, nil
}
@ -293,7 +296,7 @@ var _ = Describe("Client", func() {
addr,
"quic.clemnte.io:1337",
nil,
&Config{Versions: supportedVersionsWithoutGQUIC44},
&Config{},
)
Expect(err).To(MatchError(context.Canceled))
close(dialed)
@ -313,16 +316,18 @@ var _ = Describe("Client", func() {
var runner sessionRunner
sess := NewMockQuicSession(mockCtrl)
newClientSession = func(
conn connection,
_ connection,
runnerP sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
runner = runnerP
return sess, nil
@ -337,7 +342,7 @@ var _ = Describe("Client", func() {
addr,
"quic.clemnte.io:1337",
nil,
&Config{Versions: supportedVersionsWithoutGQUIC44},
&Config{},
)
Expect(err).ToNot(HaveOccurred())
})
@ -354,14 +359,16 @@ var _ = Describe("Client", func() {
newClientSession = func(
connP connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
conn = connP
close(sessionCreated)
@ -397,39 +404,20 @@ var _ = Describe("Client", func() {
Context("quic.Config", func() {
It("setups with the right values", func() {
config := &Config{
HandshakeTimeout: 1337 * time.Minute,
IdleTimeout: 42 * time.Hour,
RequestConnectionIDOmission: true,
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: 4321,
ConnectionIDLength: 13,
Versions: supportedVersionsWithoutGQUIC44,
HandshakeTimeout: 1337 * time.Minute,
IdleTimeout: 42 * time.Hour,
MaxIncomingStreams: 1234,
MaxIncomingUniStreams: 4321,
ConnectionIDLength: 13,
}
c := populateClientConfig(config, false)
Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute))
Expect(c.IdleTimeout).To(Equal(42 * time.Hour))
Expect(c.RequestConnectionIDOmission).To(BeTrue())
Expect(c.MaxIncomingStreams).To(Equal(1234))
Expect(c.MaxIncomingUniStreams).To(Equal(4321))
Expect(c.ConnectionIDLength).To(Equal(13))
})
It("uses a 0 byte connection IDs if gQUIC 44 is supported", func() {
config := &Config{
Versions: []protocol.VersionNumber{protocol.Version43, protocol.Version44},
ConnectionIDLength: 13,
}
c := populateClientConfig(config, false)
Expect(c.Versions).To(Equal([]protocol.VersionNumber{protocol.Version43, protocol.Version44}))
Expect(c.ConnectionIDLength).To(BeZero())
})
It("doesn't use 0-byte connection IDs when dialing an address", func() {
config := &Config{Versions: supportedVersionsWithoutGQUIC44}
c := populateClientConfig(config, false)
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
})
It("errors when the Config contains an invalid version", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
@ -470,194 +458,160 @@ var _ = Describe("Client", func() {
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout))
Expect(c.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
Expect(c.RequestConnectionIDOmission).To(BeFalse())
})
})
Context("gQUIC", func() {
It("errors if it can't create a session", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
It("creates new TLS sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
testErr := errors.New("error creating session")
newClientSession = func(
_ connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (quicSession, error) {
return nil, testErr
}
_, err := Dial(
packetConn,
addr,
"quic.clemente.io:1337",
nil,
&Config{Versions: supportedVersionsWithoutGQUIC44},
)
Expect(err).To(MatchError(testErr))
})
})
Context("IETF QUIC", func() {
It("creates new TLS sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
c := make(chan struct{})
var cconn connection
var version protocol.VersionNumber
var conf *Config
newTLSClientSession = func(
connP connection,
_ sessionRunner,
tokenP []byte,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
configP *Config,
_ *tls.Config,
params *handshake.TransportParameters,
_ protocol.VersionNumber, /* initial version */
_ protocol.PacketNumber,
_ utils.Logger,
versionP protocol.VersionNumber,
) (quicSession, error) {
cconn = connP
version = versionP
conf = configP
close(c)
// TODO: check connection IDs?
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
return sess, nil
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
})
It("creates a new session when the server performs a retry", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
go handler.handlePacket(&receivedPacket{
header: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
Token: []byte("foobar"),
DestConnectionID: id,
OrigDestConnectionID: connID,
},
})
})
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
cl.config = config
run1 := make(chan error)
sess1 := NewMockQuicSession(mockCtrl)
sess1.EXPECT().run().DoAndReturn(func() error {
return <-run1
})
sess1.EXPECT().destroy(errCloseSessionForRetry).Do(func(e error) {
run1 <- e
})
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessions := make(chan quicSession, 2)
sessions <- sess1
sessions <- sess2
newTLSClientSession = func(
_ connection,
_ sessionRunner,
_ []byte,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber, /* initial version */
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
return <-sessions, nil
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(sessions).To(BeEmpty())
})
It("only accepts a single retry", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
go handler.handlePacket(&receivedPacket{
header: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
Token: []byte("foobar"),
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4},
DestConnectionID: id,
OrigDestConnectionID: connID,
Version: protocol.VersionTLS,
},
})
}).AnyTimes()
manager.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes()
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
cl.config = config
sessions := make(chan quicSession, 2)
run := make(chan error)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
c := make(chan struct{})
var cconn connection
var version protocol.VersionNumber
var conf *Config
newClientSession = func(
connP connection,
_ sessionRunner,
tokenP []byte,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
configP *Config,
_ *tls.Config,
params *handshake.TransportParameters,
_ protocol.VersionNumber, /* initial version */
_ protocol.PacketNumber,
_ utils.Logger,
versionP protocol.VersionNumber,
) (quicSession, error) {
cconn = connP
version = versionP
conf = configP
close(c)
// TODO: check connection IDs?
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().DoAndReturn(func() error {
defer GinkgoRecover()
var err error
Eventually(run).Should(Receive(&err))
return err
})
sess.EXPECT().destroy(gomock.Any()).Do(func(e error) {
run <- e
})
sessions <- sess
doneErr := errors.New("nothing to do")
sess = NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(doneErr)
sessions <- sess
sess.EXPECT().run()
return sess, nil
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
})
newTLSClientSession = func(
_ connection,
_ sessionRunner,
_ []byte,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber, /* initial version */
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
return <-sessions, nil
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(doneErr))
Expect(sessions).To(BeEmpty())
It("creates a new session when the server performs a retry", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
go handler.handlePacket(&receivedPacket{
header: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
Token: []byte("foobar"),
DestConnectionID: id,
OrigDestConnectionID: connID,
},
})
})
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
cl.config = config
run1 := make(chan error)
sess1 := NewMockQuicSession(mockCtrl)
sess1.EXPECT().run().DoAndReturn(func() error {
return <-run1
})
sess1.EXPECT().destroy(errCloseSessionForRetry).Do(func(e error) {
run1 <- e
})
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessions := make(chan quicSession, 2)
sessions <- sess1
sessions <- sess2
newClientSession = func(
conn connection,
_ sessionRunner,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
return <-sessions, nil
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
Expect(sessions).To(BeEmpty())
})
It("only accepts a single retry", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
go handler.handlePacket(&receivedPacket{
header: &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
Token: []byte("foobar"),
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4},
DestConnectionID: id,
OrigDestConnectionID: connID,
Version: protocol.VersionTLS,
},
})
}).AnyTimes()
manager.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes()
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
cl.config = config
sessions := make(chan quicSession, 2)
run := make(chan error)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().DoAndReturn(func() error {
defer GinkgoRecover()
var err error
Eventually(run).Should(Receive(&err))
return err
})
sess.EXPECT().destroy(gomock.Any()).Do(func(e error) {
run <- e
})
sessions <- sess
doneErr := errors.New("nothing to do")
sess = NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(doneErr)
sessions <- sess
newClientSession = func(
conn connection,
_ sessionRunner,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
return <-sessions, nil
}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(doneErr))
Expect(sessions).To(BeEmpty())
})
Context("version negotiation", func() {
@ -681,14 +635,16 @@ var _ = Describe("Client", func() {
newClientSession = func(
conn connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ []byte, // token
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ *tls.Config,
_ *handshake.TransportParameters,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ protocol.PacketNumber,
_ utils.Logger,
_ protocol.VersionNumber,
) (quicSession, error) {
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
sess := NewMockQuicSession(mockCtrl)
@ -700,7 +656,7 @@ var _ = Describe("Client", func() {
addr,
"quic.clemente.io:1337",
nil,
&Config{Versions: supportedVersionsWithoutGQUIC44},
&Config{},
)
Expect(err).To(MatchError(testErr))
})
@ -721,103 +677,6 @@ var _ = Describe("Client", func() {
Expect(cl.versionNegotiated).To(BeTrue())
})
It("changes the version after receiving a Version Negotiation Packet", func() {
phm := NewMockPacketHandlerManager(mockCtrl)
phm.EXPECT().Add(connID, gomock.Any()).Times(2)
cl.packetHandlers = phm
version1 := protocol.Version39
version2 := protocol.Version39 + 1
Expect(version2.UsesTLS()).To(BeFalse())
sess1 := NewMockQuicSession(mockCtrl)
run1 := make(chan struct{})
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessionChan := make(chan *MockQuicSession, 2)
sessionChan <- sess1
sessionChan <- sess2
newClientSession = func(
_ connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (quicSession, error) {
return <-sessionChan, nil
}
cl.tlsConf = &tls.Config{}
cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2}}
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cl.dial(context.Background())
Expect(err).ToNot(HaveOccurred())
close(dialed)
}()
Eventually(sessionChan).Should(HaveLen(1))
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2}))
Eventually(sessionChan).Should(BeEmpty())
})
It("only accepts one version negotiation packet", func() {
phm := NewMockPacketHandlerManager(mockCtrl)
phm.EXPECT().Add(connID, gomock.Any()).Times(2)
cl.packetHandlers = phm
version1 := protocol.Version39
version2 := protocol.Version39 + 1
version3 := protocol.Version39 + 2
Expect(version2.UsesTLS()).To(BeFalse())
Expect(version3.UsesTLS()).To(BeFalse())
sess1 := NewMockQuicSession(mockCtrl)
run1 := make(chan struct{})
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
sess2 := NewMockQuicSession(mockCtrl)
sess2.EXPECT().run()
sessionChan := make(chan *MockQuicSession, 2)
sessionChan <- sess1
sessionChan <- sess2
newClientSession = func(
_ connection,
_ sessionRunner,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (quicSession, error) {
return <-sessionChan, nil
}
cl.tlsConf = &tls.Config{}
cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2, version3}}
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
err := cl.dial(context.Background())
Expect(err).ToNot(HaveOccurred())
close(dialed)
}()
Eventually(sessionChan).Should(HaveLen(1))
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2}))
Eventually(sessionChan).Should(BeEmpty())
Expect(cl.version).To(Equal(version2))
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version3}))
Eventually(dialed).Should(BeClosed())
Expect(cl.version).To(Equal(version2))
})
It("errors if no matching version is found", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().destroy(qerr.InvalidVersion)
@ -863,26 +722,8 @@ var _ = Describe("Client", func() {
Expect(cl.GetVersion()).To(Equal(cl.version))
})
It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
cl.version = versionGQUICFrames
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
cl.config = &Config{RequestConnectionIDOmission: false}
hdr := &wire.Header{
IsPublicHeader: true,
PacketNumber: 1,
PacketNumberLen: 1,
}
err := cl.handlePacketImpl(&receivedPacket{
remoteAddr: addr,
header: hdr,
})
Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation"))
})
It("ignores packets with the wrong destination connection ID", func() {
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
cl.version = versionIETFFrames
cl.config = &Config{RequestConnectionIDOmission: false}
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
Expect(connID).ToNot(Equal(connID2))
hdr := &wire.Header{
@ -890,7 +731,6 @@ var _ = Describe("Client", func() {
SrcConnectionID: connID,
PacketNumber: 1,
PacketNumberLen: protocol.PacketNumberLen1,
Version: versionIETFFrames,
}
err := cl.handlePacketImpl(&receivedPacket{
remoteAddr: addr,
@ -898,110 +738,4 @@ var _ = Describe("Client", func() {
})
Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID)))
})
It("creates new gQUIC sessions with the right parameters", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
c := make(chan struct{})
var cconn connection
var hostname string
var version protocol.VersionNumber
var conf *Config
newClientSession = func(
connP connection,
_ sessionRunner,
versionP protocol.VersionNumber,
connIDP protocol.ConnectionID,
_ protocol.ConnectionID,
tlsConf *tls.Config,
configP *Config,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (quicSession, error) {
cconn = connP
hostname = tlsConf.ServerName
version = versionP
conf = configP
connID = connIDP
close(c)
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run()
return sess, nil
}
config := &Config{Versions: supportedVersionsWithoutGQUIC44}
_, err := Dial(
packetConn,
addr,
"quic.clemente.io:1337",
nil,
config,
)
Expect(err).ToNot(HaveOccurred())
Eventually(c).Should(BeClosed())
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
Expect(hostname).To(Equal("quic.clemente.io"))
Expect(version).To(Equal(config.Versions[0]))
Expect(conf.Versions).To(Equal(config.Versions))
})
Context("Public Reset handling", func() {
var (
pr []byte
hdr *wire.Header
hdrLen int
)
BeforeEach(func() {
cl.config = &Config{}
pr = wire.WritePublicReset(cl.destConnID, 1, 0)
r := bytes.NewReader(pr)
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
hdr, err = iHdr.Parse(r, protocol.PerspectiveServer, versionGQUICFrames)
Expect(err).ToNot(HaveOccurred())
hdrLen = r.Len()
})
It("closes the session when receiving a Public Reset", func() {
cl.version = versionGQUICFrames
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().closeRemote(gomock.Any()).Do(func(err error) {
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset))
})
cl.session = sess
cl.handlePacketImpl(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: pr[len(pr)-hdrLen:],
})
})
It("ignores Public Resets from the wrong remote address", func() {
cl.version = versionGQUICFrames
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678}
err := cl.handlePacketImpl(&receivedPacket{
remoteAddr: spoofedAddr,
header: hdr,
data: pr[len(pr)-hdrLen:],
})
Expect(err).To(MatchError("Received a spoofed Public Reset"))
})
It("ignores unparseable Public Resets", func() {
cl.version = versionGQUICFrames
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
err := cl.handlePacketImpl(&receivedPacket{
remoteAddr: addr,
header: hdr,
data: pr[len(pr)-hdrLen : len(pr)-5], // cut off the last 5 bytes
})
Expect(err.Error()).To(ContainSubstring("Received a Public Reset. An error occurred parsing the packet"))
})
})
})

View file

@ -7,15 +7,12 @@ import (
"net/http"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
func main() {
verbose := flag.Bool("v", false, "verbose")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
quiet := flag.Bool("q", false, "don't print the data")
flag.Parse()
urls := flag.Args()
@ -29,14 +26,7 @@ func main() {
}
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
roundTripper := &h2quic.RoundTripper{
QuicConfig: &quic.Config{Versions: versions},
}
roundTripper := &h2quic.RoundTripper{}
defer roundTripper.Close()
hclient := &http.Client{
Transport: roundTripper,

View file

@ -17,9 +17,7 @@ import (
_ "net/http/pprof"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
@ -123,7 +121,6 @@ func main() {
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
www := flag.String("www", "/var/www", "www data")
tcp := flag.Bool("tcp", false, "also listen on TCP")
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
flag.Parse()
logger := utils.DefaultLogger
@ -135,11 +132,6 @@ func main() {
}
logger.SetLogTimeFormat("")
versions := protocol.SupportedVersions
if *tls {
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
}
certFile := *certPath + "/fullchain.pem"
keyFile := *certPath + "/privkey.pem"
@ -159,8 +151,7 @@ func main() {
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
} else {
server := h2quic.Server{
Server: &http.Server{Addr: bCap},
QuicConfig: &quic.Config{Versions: versions},
Server: &http.Server{Addr: bCap},
}
err = server.ListenAndServeTLS(certFile, keyFile)
}

View file

@ -52,10 +52,7 @@ type client struct {
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDOmission: true,
KeepAlive: true,
}
var defaultQuicConfig = &quic.Config{KeepAlive: true}
// newClient creates a new client
func newClient(

View file

@ -1,210 +0,0 @@
package chrome_test
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync/atomic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gexec"
"testing"
)
const (
dataLen = 500 * 1024 // 500 KB
dataLongLen = 50 * 1024 * 1024 // 50 MB
)
var (
nFilesUploaded int32 // should be used atomically
doneCalled utils.AtomicBool
)
func TestChrome(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Chrome Suite")
}
func init() {
// Requires the len & num GET parameters, e.g. /uploadtest?len=100&num=1
http.HandleFunc("/uploadtest", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
response := uploadHTML
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
_, err := io.WriteString(w, response)
Expect(err).NotTo(HaveOccurred())
})
// Requires the len & num GET parameters, e.g. /downloadtest?len=100&num=1
http.HandleFunc("/downloadtest", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
response := downloadHTML
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
_, err := io.WriteString(w, response)
Expect(err).NotTo(HaveOccurred())
})
http.HandleFunc("/uploadhandler", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
l, err := strconv.Atoi(r.URL.Query().Get("len"))
Expect(err).NotTo(HaveOccurred())
defer r.Body.Close()
actual, err := ioutil.ReadAll(r.Body)
Expect(err).NotTo(HaveOccurred())
Expect(bytes.Equal(actual, testserver.GeneratePRData(l))).To(BeTrue())
atomic.AddInt32(&nFilesUploaded, 1)
})
http.HandleFunc("/done", func(w http.ResponseWriter, r *http.Request) {
doneCalled.Set(true)
})
}
var _ = AfterEach(func() {
testserver.StopQuicServer()
atomic.StoreInt32(&nFilesUploaded, 0)
doneCalled.Set(false)
})
func getChromePath() string {
if runtime.GOOS == "darwin" {
return "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
}
if path, err := exec.LookPath("google-chrome"); err == nil {
return path
}
if path, err := exec.LookPath("chromium-browser"); err == nil {
return path
}
Fail("No Chrome executable found.")
return ""
}
func chromeTest(version protocol.VersionNumber, url string, blockUntilDone func()) {
userDataDir, err := ioutil.TempDir("", "quic-go-test-chrome-dir")
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(userDataDir)
path := getChromePath()
args := []string{
"--disable-gpu",
"--no-first-run=true",
"--no-default-browser-check=true",
"--user-data-dir=" + userDataDir,
"--enable-quic=true",
"--no-proxy-server=true",
"--no-sandbox",
"--origin-to-force-quic-on=quic.clemente.io:443",
fmt.Sprintf(`--host-resolver-rules=MAP quic.clemente.io:443 127.0.0.1:%s`, testserver.Port()),
fmt.Sprintf("--quic-version=QUIC_VERSION_%s", version.ToAltSvc()),
url,
}
utils.DefaultLogger.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '"))
command := exec.Command(path, args...)
session, err := gexec.Start(command, nil, nil)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
blockUntilDone()
}
func waitForDone() {
Eventually(func() bool { return doneCalled.Get() }, 60).Should(BeTrue())
}
func waitForNUploaded(expected int) func() {
return func() {
Eventually(func() int32 {
return atomic.LoadInt32(&nFilesUploaded)
}, 60).Should(BeEquivalentTo(expected))
}
}
const commonJS = `
var buf = new ArrayBuffer(LENGTH);
var prng = new Uint8Array(buf);
var seed = 1;
for (var i = 0; i < LENGTH; i++) {
// https://en.wikipedia.org/wiki/Lehmer_random_number_generator
seed = seed * 48271 % 2147483647;
prng[i] = seed;
}
`
const uploadHTML = `
<html>
<body>
<script>
console.log("Running DL test...");
` + commonJS + `
for (var i = 0; i < NUM; i++) {
var req = new XMLHttpRequest();
req.open("POST", "/uploadhandler?len=" + LENGTH, true);
req.send(buf);
}
</script>
</body>
</html>
`
const downloadHTML = `
<html>
<body>
<script>
console.log("Running DL test...");
` + commonJS + `
function verify(data) {
if (data.length !== LENGTH) return false;
for (var i = 0; i < LENGTH; i++) {
if (data[i] !== prng[i]) return false;
}
return true;
}
var nOK = 0;
for (var i = 0; i < NUM; i++) {
let req = new XMLHttpRequest();
req.responseType = "arraybuffer";
req.open("POST", "/prdata?len=" + LENGTH, true);
req.onreadystatechange = function () {
if (req.readyState === XMLHttpRequest.DONE && req.status === 200) {
if (verify(new Uint8Array(req.response))) {
nOK++;
if (nOK === NUM) {
console.log("Done :)");
var reqDone = new XMLHttpRequest();
reqDone.open("GET", "/done");
reqDone.send();
}
}
}
};
req.send();
}
</script>
</body>
</html>
`

View file

@ -1,71 +0,0 @@
package chrome_test
import (
"fmt"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
)
var _ = Describe("Chrome tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with version %s", version), func() {
JustBeforeEach(func() {
testserver.StartQuicServer([]protocol.VersionNumber{version})
})
It("downloads a small file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLen),
waitForDone,
)
})
It("downloads a large file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLongLen),
waitForDone,
)
})
It("loads a large number of files", func() {
chromeTest(
version,
"https://quic.clemente.io/downloadtest?num=4&len=100",
waitForDone,
)
})
It("uploads a small file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLen),
waitForNUploaded(1),
)
})
It("uploads a large file", func() {
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLongLen),
waitForNUploaded(1),
)
})
It("uploads many small files", func() {
num := protocol.DefaultMaxIncomingStreams + 20
chromeTest(
version,
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=%d&len=%d", num, dataLen),
waitForNUploaded(num),
)
})
})
}
})

View file

@ -1,137 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
mrand "math/rand"
"os/exec"
"strconv"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
var _ = Describe("Drop tests", func() {
var proxy *quicproxy.QuicProxy
startProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DropPacket: dropCallback,
})
Expect(err).ToNot(HaveOccurred())
}
downloadFile := func(version protocol.VersionNumber) {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
downloadHello := func(version protocol.VersionNumber) {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/hello",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: Hello, World!\n"))
}
deterministicDropper := func(p, interval, dropInARow uint64) bool {
return (p % interval) < dropInARow
}
stochasticDropper := func(freq int) bool {
return mrand.Int63n(int64(freq)) == 0
}
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
})
for _, v := range protocol.SupportedVersions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {
Context("during the crypto handshake", func() {
for _, d := range directions {
direction := d
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 1 && d.Is(direction)
}, version)
downloadHello(version)
})
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p == 2 && d.Is(direction)
}, version)
downloadHello(version)
})
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return d.Is(direction) && stochasticDropper(5)
}, version)
downloadHello(version)
})
}
})
Context("after the crypto handshake", func() {
for _, d := range directions {
direction := d
It(fmt.Sprintf("downloads a file when every 5th packet is dropped in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && deterministicDropper(p, 5, 1)
}, version)
downloadFile(version)
})
It(fmt.Sprintf("downloads a file when 1/5th of all packet are dropped randomly in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && stochasticDropper(5)
}, version)
downloadFile(version)
})
It(fmt.Sprintf("downloads a file when 10 packets every 100 packet are dropped in %s direction", d), func() {
startProxy(func(d quicproxy.Direction, p uint64) bool {
return p >= 10 && d.Is(direction) && deterministicDropper(p, 100, 10)
}, version)
downloadFile(version)
})
}
})
})
}
})

View file

@ -1,45 +0,0 @@
package gquic_test
import (
"fmt"
"math/rand"
"path/filepath"
"runtime"
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
var (
clientPath string
serverPath string
)
func TestIntegration(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "GQuic Tests Suite")
}
var _ = BeforeSuite(func() {
rand.Seed(GinkgoRandomSeed())
})
var _ = JustBeforeEach(func() {
testserver.StartQuicServer(nil)
})
var _ = AfterEach(testserver.StopQuicServer)
func init() {
_, thisfile, _, ok := runtime.Caller(0)
if !ok {
panic("Failed to get current path")
}
clientPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/client-%s-debug", runtime.GOOS))
serverPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/server-%s-debug", runtime.GOOS))
}

View file

@ -1,98 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
"os/exec"
"sync"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
_ "github.com/lucas-clemente/quic-clients" // download clients
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("Integration tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("gets a simple file", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/hello",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 5).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: Hello, World!\n"))
})
It("posts and reads a body", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"--body=foo",
"https://quic.clemente.io/echo",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 5).Should(Exit(0))
Expect(session.Out).To(Say(":status 200"))
Expect(session.Out).To(Say("body: foo\n"))
})
It("gets a file", func() {
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 10).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
})
It("gets many copies of a file in parallel", func() {
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+testserver.Port(),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}()
}
wg.Wait()
})
})
}
})

View file

@ -1,95 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
"math/rand"
"os/exec"
"strconv"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gexec"
)
// get a random duration between min and max
func getRandomDuration(min, max time.Duration) time.Duration {
return min + time.Duration(rand.Int63n(int64(max-min)))
}
var _ = Describe("Random Duration Generator", func() {
It("gets a random RTT", func() {
var min time.Duration = time.Hour
var max time.Duration
var sum time.Duration
rep := 10000
for i := 0; i < rep; i++ {
val := getRandomDuration(100*time.Millisecond, 500*time.Millisecond)
sum += val
if val < min {
min = val
}
if val > max {
max = val
}
}
avg := sum / time.Duration(rep)
Expect(avg).To(BeNumerically("~", 300*time.Millisecond, 5*time.Millisecond))
Expect(min).To(BeNumerically(">=", 100*time.Millisecond))
Expect(min).To(BeNumerically("<", 105*time.Millisecond))
Expect(max).To(BeNumerically(">", 495*time.Millisecond))
Expect(max).To(BeNumerically("<=", 500*time.Millisecond))
})
})
var _ = Describe("Random RTT", func() {
var proxy *quicproxy.QuicProxy
runRTTTest := func(minRtt, maxRtt time.Duration, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
return getRandomDuration(minRtt, maxRtt)
},
})
Expect(err).ToNot(HaveOccurred())
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(time.Millisecond)
})
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("gets a file a random RTT between 10ms and 30ms", func() {
runRTTTest(10*time.Millisecond, 30*time.Millisecond, version)
})
})
}
})

View file

@ -1,66 +0,0 @@
package gquic_test
import (
"bytes"
"fmt"
"os/exec"
"strconv"
"time"
_ "github.com/lucas-clemente/quic-clients" // download clients
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("non-zero RTT", func() {
var proxy *quicproxy.QuicProxy
runRTTTest := func(rtt time.Duration, version protocol.VersionNumber) {
var err error
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
RemoteAddr: "localhost:" + testserver.Port(),
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
return rtt / 2
},
})
Expect(err).ToNot(HaveOccurred())
command := exec.Command(
clientPath,
"--quic-version="+version.ToAltSvc(),
"--host=127.0.0.1",
"--port="+strconv.Itoa(proxy.LocalPort()),
"https://quic.clemente.io/prdata",
)
session, err := Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
defer session.Kill()
Eventually(session, 20).Should(Exit(0))
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
}
AfterEach(func() {
err := proxy.Close()
Expect(err).ToNot(HaveOccurred())
time.Sleep(time.Millisecond)
})
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
Context(fmt.Sprintf("with QUIC version %s", version), func() {
roundTrips := [...]int{10, 50, 100, 200}
for _, rtt := range roundTrips {
It(fmt.Sprintf("gets a 500kB file with %dms RTT", rtt), func() {
runRTTTest(time.Duration(rtt)*time.Millisecond, version)
})
}
})
}
})

View file

@ -1,218 +0,0 @@
package gquic_test
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
mrand "math/rand"
"net/http"
"os"
"os/exec"
"path/filepath"
"strconv"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
. "github.com/onsi/gomega/gexec"
)
var _ = Describe("Server tests", func() {
for i := range protocol.SupportedVersions {
version := protocol.SupportedVersions[i]
var (
serverPort string
tmpDir string
session *Session
client *http.Client
)
generateCA := func() (*rsa.PrivateKey, *x509.Certificate) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
templateRoot := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, templateRoot, templateRoot, &key.PublicKey, key)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return key, cert
}
// prepare the file such that it can be by the quic_server
// some HTTP headers neeed to be prepended, see https://www.chromium.org/quic/playing-with-quic
createDownloadFile := func(filename string, data []byte) {
dataDir := filepath.Join(tmpDir, "quic.clemente.io")
err := os.Mkdir(dataDir, 0777)
Expect(err).ToNot(HaveOccurred())
f, err := os.Create(filepath.Join(dataDir, filename))
Expect(err).ToNot(HaveOccurred())
defer f.Close()
_, err = f.Write([]byte("HTTP/1.1 200 OK\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("Content-Type: text/html\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("X-Original-Url: https://quic.clemente.io:" + serverPort + "/" + filename + "\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write([]byte("Content-Length: " + strconv.Itoa(len(data)) + "\n\n"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write(data)
Expect(err).ToNot(HaveOccurred())
}
// download files must be create *before* the quic_server is started
// the quic_server reads its data dir on startup, and only serves those files that were already present then
startServer := func(version protocol.VersionNumber) {
defer GinkgoRecover()
var err error
command := exec.Command(
serverPath,
"--quic_response_cache_dir="+filepath.Join(tmpDir, "quic.clemente.io"),
"--key_file="+filepath.Join(tmpDir, "key.pkcs8"),
"--certificate_file="+filepath.Join(tmpDir, "cert.pem"),
"--quic-version="+strconv.Itoa(int(version)),
"--port="+serverPort,
)
session, err = Start(command, nil, GinkgoWriter)
Expect(err).NotTo(HaveOccurred())
}
stopServer := func() {
session.Kill()
}
BeforeEach(func() {
serverPort = strconv.Itoa(20000 + int(mrand.Int31n(10000)))
var err error
tmpDir, err = ioutil.TempDir("", "quic-server-certs")
Expect(err).ToNot(HaveOccurred())
// generate an RSA key pair for the server
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
// save the private key in PKCS8 format to disk (required by quic_server)
pkcs8key, err := asn1.Marshal(struct { // copied from the x509 package
Version int
Algo pkix.AlgorithmIdentifier
PrivateKey []byte
}{
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Algo: pkix.AlgorithmIdentifier{
Algorithm: asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1},
Parameters: asn1.RawValue{Tag: 5},
},
})
Expect(err).ToNot(HaveOccurred())
f, err := os.Create(filepath.Join(tmpDir, "key.pkcs8"))
Expect(err).ToNot(HaveOccurred())
_, err = f.Write(pkcs8key)
Expect(err).ToNot(HaveOccurred())
f.Close()
// generate a Certificate Authority
// this CA is used to sign the server's key
// it is set as a valid CA in the QUIC client
rootKey, CACert := generateCA()
// generate the server certificate
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-30 * time.Minute),
NotAfter: time.Now().Add(30 * time.Minute),
Subject: pkix.Name{CommonName: "quic.clemente.io"},
}
certDER, err := x509.CreateCertificate(rand.Reader, template, CACert, &key.PublicKey, rootKey)
Expect(err).ToNot(HaveOccurred())
// save the certificate to disk
certOut, err := os.Create(filepath.Join(tmpDir, "cert.pem"))
Expect(err).ToNot(HaveOccurred())
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
certOut.Close()
// prepare the h2quic.client
certPool := x509.NewCertPool()
certPool.AddCert(CACert)
client = &http.Client{
Transport: &h2quic.RoundTripper{
TLSClientConfig: &tls.Config{RootCAs: certPool},
QuicConfig: &quic.Config{
Versions: []protocol.VersionNumber{version},
},
},
}
})
AfterEach(func() {
Expect(tmpDir).ToNot(BeEmpty())
err := os.RemoveAll(tmpDir)
Expect(err).ToNot(HaveOccurred())
tmpDir = ""
})
Context(fmt.Sprintf("with QUIC version %s", version), func() {
It("downloads a hello", func() {
data := []byte("Hello world!\n")
createDownloadFile("hello", data)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/hello")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(data))
})
It("downloads a small file", func() {
createDownloadFile("file.dat", testserver.PRData)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRData))
})
It("downloads a large file", func() {
createDownloadFile("file.dat", testserver.PRDataLong)
startServer(version)
defer stopServer()
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
Expect(err).ToNot(HaveOccurred())
Expect(rsp.StatusCode).To(Equal(200))
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 20*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRDataLong))
})
})
}
})

View file

@ -13,6 +13,7 @@ import (
"github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes"
@ -21,8 +22,7 @@ import (
var _ = Describe("Client tests", func() {
var client *http.Client
// also run some tests with the TLS handshake
versions := append(protocol.SupportedVersions, protocol.VersionTLS)
versions := protocol.SupportedVersions
BeforeEach(func() {
err := os.Setenv("HOSTALIASES", "quic.clemente.io 127.0.0.1")

View file

@ -60,42 +60,32 @@ var _ = Describe("Connection ID lengths tests", func() {
Expect(data).To(Equal(testserver.PRData))
}
Context("IETF QUIC", func() {
It("downloads a file using a 0-byte connection ID for the client", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
It("downloads a file using a 0-byte connection ID for the client", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
It("downloads a file when both client and server use a random connection ID length", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
Context("gQUIC", func() {
It("downloads a file using a 0-byte connection ID for the client", func() {
ln := runServer(&quic.Config{})
defer ln.Close()
runClient(ln.Addr(), &quic.Config{RequestConnectionIDOmission: true})
})
It("downloads a file when both client and server use a random connection ID length", func() {
serverConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
clientConf := &quic.Config{
ConnectionIDLength: randomConnIDLen(),
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}
ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
})

View file

@ -87,128 +87,57 @@ var _ = Describe("Handshake RTT tests", func() {
expectDurationInRTTs(1)
})
Context("gQUIC", func() {
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
var clientConfig *quic.Config
var clientTLSConfig *tls.Config
It("does version negotiation in 1 RTT, IETF QUIC => gQUIC", func() {
clientConfig := &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS, protocol.SupportedVersions[0]},
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
&tls.Config{InsecureSkipVerify: true},
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(4)
})
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
clientTLSConfig = &tls.Config{
InsecureSkipVerify: true,
ServerName: "quic.clemente.io",
}
})
Context("IETF QUIC", func() {
var clientConfig *quic.Config
var clientTLSConfig *tls.Config
// 1 RTT for verifying the source address
// 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
BeforeEach(func() {
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
clientTLSConfig = &tls.Config{
InsecureSkipVerify: true,
ServerName: "quic.clemente.io",
}
})
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(1)
})
// 1 RTT for verifying the source address
// 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("does version negotiation in 1 RTT, gQUIC => IETF QUIC", func() {
clientConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[0], protocol.VersionTLS}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(1)
})
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return false
}
clientConfig.HandshakeTimeout = 500 * time.Millisecond
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
})
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
return false
}
clientConfig.HandshakeTimeout = 500 * time.Millisecond
runServerAndProxy()
_, err := quic.DialAddr(
proxy.LocalAddr().String(),
clientTLSConfig,
clientConfig,
)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
})
})

View file

@ -91,7 +91,7 @@ var _ = Describe("Handshake tests", func() {
})
Context("Certifiate validation", func() {
for _, v := range []protocol.VersionNumber{protocol.Version39, protocol.VersionTLS} {
for _, v := range protocol.SupportedVersions {
version := v
Context(fmt.Sprintf("using %s", version), func() {

View file

@ -18,15 +18,9 @@ import (
)
var _ = Describe("Multiplexing", func() {
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
for _, v := range protocol.SupportedVersions {
version := v
// gQUIC 44 uses 0 byte connection IDs for packets sent to the client
// It's not possible to do demultiplexing.
if v == protocol.Version44 {
continue
}
Context(fmt.Sprintf("with QUIC version %s", version), func() {
runServer := func(ln quic.Listener) {
go func() {
@ -143,10 +137,6 @@ var _ = Describe("Multiplexing", func() {
Context("multiplexing server and client on the same conn", func() {
It("connects to itself", func() {
if version != protocol.VersionTLS {
Skip("Connecting to itself only works with IETF QUIC.")
}
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err := net.ListenUDP("udp", addr)

View file

@ -18,7 +18,7 @@ import (
)
var _ = Describe("non-zero RTT", func() {
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
for _, v := range protocol.SupportedVersions {
version := v
Context(fmt.Sprintf("with QUIC version %s", version), func() {

View file

@ -16,15 +16,6 @@ type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
const (
// VersionGQUIC39 is gQUIC version 39.
VersionGQUIC39 = protocol.Version39
// VersionGQUIC43 is gQUIC version 43.
VersionGQUIC43 = protocol.Version43
// VersionGQUIC44 is gQUIC version 44.
VersionGQUIC44 = protocol.Version44
)
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
@ -164,11 +155,7 @@ type Config struct {
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []VersionNumber
// Ask the server to omit the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDOmission bool
// The length of the connection ID in bytes. Only valid for IETF QUIC.
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
@ -201,7 +188,6 @@ type Config struct {
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.

View file

@ -27,7 +27,6 @@ type SentPacketHandler interface {
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)

View file

@ -16,8 +16,6 @@ func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f wire.Frame) bool {
switch f.(type) {
case *wire.StopWaitingFrame:
return false
case *wire.AckFrame:
return false
default:

View file

@ -11,10 +11,8 @@ import (
var _ = Describe("retransmittable frames", func() {
for fl, el := range map[wire.Frame]bool{
&wire.AckFrame{}: false,
&wire.StopWaitingFrame{}: false,
&wire.BlockedFrame{}: true,
&wire.ConnectionCloseFrame{}: true,
&wire.GoawayFrame{}: true,
&wire.PingFrame{}: true,
&wire.RstStreamFrame{}: true,
&wire.StreamFrame{}: true,

View file

@ -45,8 +45,7 @@ type sentPacketHandler struct {
lowestPacketNotConfirmedAcked protocol.PacketNumber
largestSentBeforeRTO protocol.PacketNumber
packetHistory *sentPacketHistory
stopWaitingManager stopWaitingManager
packetHistory *sentPacketHistory
retransmissionQueue []*Packet
@ -90,12 +89,11 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
)
return &sentPacketHandler{
packetHistory: newSentPacketHistory(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
packetHistory: newSentPacketHistory(),
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
@ -110,15 +108,13 @@ func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure ||
packet.EncryptionLevel == protocol.Encryption1RTT {
if packet.EncryptionLevel == protocol.Encryption1RTT {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.EncryptionForwardSecure &&
p.EncryptionLevel != protocol.Encryption1RTT {
if p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
}
return true, nil
@ -169,8 +165,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
if packet.EncryptionLevel != protocol.EncryptionForwardSecure &&
packet.EncryptionLevel != protocol.Encryption1RTT {
if packet.EncryptionLevel != protocol.Encryption1RTT {
h.lastSentHandshakePacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
@ -217,12 +212,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
priorInFlight := h.bytesInFlight
for _, p := range ackedPackets {
// TODO(#1534): also check the encryption level for IETF QUIC
if !h.version.UsesTLS() {
if encLevel < p.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
}
}
// TODO(#1534): check the encryption level
// if encLevel < p.EncryptionLevel {
// return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
// }
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
@ -243,8 +237,6 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
@ -530,10 +522,6 @@ func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
@ -592,9 +580,7 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted &&
p.EncryptionLevel != protocol.EncryptionForwardSecure &&
p.EncryptionLevel != protocol.Encryption1RTT {
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
handshakePackets = append(handshakePackets, p)
}
return true, nil
@ -616,7 +602,6 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
return err
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
return nil
}

View file

@ -15,7 +15,7 @@ import (
func retransmittablePacket(p *Packet) *Packet {
if p.EncryptionLevel == protocol.EncryptionUnspecified {
p.EncryptionLevel = protocol.EncryptionForwardSecure
p.EncryptionLevel = protocol.Encryption1RTT
}
if p.Length == 0 {
p.Length = 1
@ -37,7 +37,7 @@ func nonRetransmittablePacket(p *Packet) *Packet {
func handshakePacket(p *Packet) *Packet {
p = retransmittablePacket(p)
p.EncryptionLevel = protocol.EncryptionUnencrypted
p.EncryptionLevel = protocol.EncryptionInitial
return p
}
@ -123,8 +123,8 @@ var _ = Describe("SentPacketHandler", func() {
It("stores the sent time of handshake packets", func() {
sendTime := time.Now().Add(-time.Minute)
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionUnencrypted}))
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.EncryptionForwardSecure}))
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionInitial}))
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.Encryption1RTT}))
Expect(handler.lastSentHandshakePacketTime).To(Equal(sendTime))
})
@ -205,7 +205,7 @@ var _ = Describe("SentPacketHandler", func() {
ack := &wire.AckFrame{
AckRanges: []wire.AckRange{{Smallest: 10, Largest: 12}},
}
err := handler.ReceivedAck(ack, 1337, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
Expect(err).To(MatchError("InvalidAckData: Received an ACK for a skipped packet number"))
})
@ -216,7 +216,7 @@ var _ = Describe("SentPacketHandler", func() {
{Smallest: 10, Largest: 10},
},
}
err := handler.ReceivedAck(ack, 1337, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestAcked).ToNot(BeZero())
})
@ -237,7 +237,7 @@ var _ = Describe("SentPacketHandler", func() {
Context("ACK validation", func() {
It("accepts ACKs sent in packet 0", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}}
err := handler.ReceivedAck(ack, 0, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 0, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(5)))
})
@ -245,12 +245,12 @@ var _ = Describe("SentPacketHandler", func() {
It("rejects duplicate ACKs", func() {
ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}}
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 4}}}
err := handler.ReceivedAck(ack1, 1337, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack1, 1337, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
// this wouldn't happen in practice
// for testing purposes, we pretend send a different ACK frame in a duplicated packet, to be able to verify that it actually doesn't get processed
err = handler.ReceivedAck(ack2, 1337, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack2, 1337, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
})
@ -259,28 +259,28 @@ var _ = Describe("SentPacketHandler", func() {
// acks packets 0, 1, 2, 3
ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}}
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 4}}}
err := handler.ReceivedAck(ack1, 1337, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack1, 1337, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
// this wouldn't happen in practive
// a receiver wouldn't send an ACK for a lower largest acked in a packet sent later
err = handler.ReceivedAck(ack2, 1337-1, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack2, 1337-1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
})
It("rejects ACKs with a too high LargestAcked packet number", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 9999}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).To(MatchError("InvalidAckData: Received ACK for an unsent package"))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
})
It("ignores repeated ACKs", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}}
err := handler.ReceivedAck(ack, 1337, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7)))
err = handler.ReceivedAck(ack, 1337+1, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 1337+1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7)))
@ -290,7 +290,7 @@ var _ = Describe("SentPacketHandler", func() {
Context("acks and nacks the right packets", func() {
It("adjusts the LargestAcked, and adjusts the bytes in flight", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(5)))
expectInPacketHistory([]protocol.PacketNumber{6, 7, 8, 9})
@ -299,7 +299,7 @@ var _ = Describe("SentPacketHandler", func() {
It("acks packet 0", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 0}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(getPacket(0)).To(BeNil())
expectInPacketHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 6, 7, 8, 9})
@ -312,14 +312,14 @@ var _ = Describe("SentPacketHandler", func() {
{Smallest: 1, Largest: 3},
},
}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{0, 4, 5})
})
It("does not ack packets below the LowestAcked", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 8}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{0, 1, 2, 9})
})
@ -333,7 +333,7 @@ var _ = Describe("SentPacketHandler", func() {
{Smallest: 1, Largest: 1},
},
}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{0, 2, 4, 5, 8})
})
@ -345,12 +345,12 @@ var _ = Describe("SentPacketHandler", func() {
{Smallest: 1, Largest: 2},
},
}
err := handler.ReceivedAck(ack1, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack1, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{0, 3, 7, 8, 9})
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(5)))
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} // now ack 3
err = handler.ReceivedAck(ack2, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack2, 2, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{0, 7, 8, 9})
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4)))
@ -363,12 +363,12 @@ var _ = Describe("SentPacketHandler", func() {
{Smallest: 0, Largest: 2},
},
}
err := handler.ReceivedAck(ack1, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack1, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{3, 7, 8, 9})
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4)))
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 7}}}
err = handler.ReceivedAck(ack2, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack2, 2, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
expectInPacketHistory([]protocol.PacketNumber{8, 9})
@ -376,7 +376,7 @@ var _ = Describe("SentPacketHandler", func() {
It("processes an ACK that contains old ACK ranges", func() {
ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}}
err := handler.ReceivedAck(ack1, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack1, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{0, 7, 8, 9})
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4)))
@ -387,7 +387,7 @@ var _ = Describe("SentPacketHandler", func() {
{Smallest: 1, Largest: 1},
},
}
err = handler.ReceivedAck(ack2, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack2, 2, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{0, 7, 9})
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(3)))
@ -403,15 +403,15 @@ var _ = Describe("SentPacketHandler", func() {
getPacket(6).SendTime = now.Add(-1 * time.Minute)
// Now, check that the proper times are used when calculating the deltas
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second))
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}}
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}}
err = handler.ReceivedAck(ack, 3, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 3, protocol.Encryption1RTT, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 1*time.Minute, 1*time.Second))
})
@ -425,7 +425,7 @@ var _ = Describe("SentPacketHandler", func() {
AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}},
DelayTime: 5 * time.Minute,
}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).NotTo(HaveOccurred())
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
})
@ -446,27 +446,27 @@ var _ = Describe("SentPacketHandler", func() {
})
It("determines which ACK we have received an ACK for", func() {
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}}, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}}, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
})
It("doesn't do anything when the acked packet didn't contain an ACK", func() {
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 15, Largest: 15}}}
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
})
It("doesn't decrease the value", func() {
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}}, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}}, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
err = handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}, 2, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
})
@ -492,7 +492,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11)))
// ack 5
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 5, Largest: 5}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
expectInPacketHistory([]protocol.PacketNumber{6})
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11)))
@ -510,36 +510,16 @@ var _ = Describe("SentPacketHandler", func() {
{Smallest: 5, Largest: 5},
},
}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.packetHistory.Len()).To(BeZero())
Expect(handler.bytesInFlight).To(BeZero())
})
})
Context("Retransmission handling", func() {
It("does not dequeue a packet if no ack has been received", func() {
handler.SentPacket(&Packet{PacketNumber: 1})
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
Context("STOP_WAITINGs", func() {
It("gets a STOP_WAITING frame", func() {
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1}))
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2}))
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 3}}}
err := handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 4}))
})
It("gets a STOP_WAITING frame after queueing a retransmission", func() {
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 5}))
handler.queuePacketForRetransmission(getPacket(5))
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
})
})
It("does not dequeue a packet if no ack has been received", func() {
handler.SentPacket(&Packet{PacketNumber: 1})
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
})
Context("congestion", func() {
@ -580,7 +560,7 @@ var _ = Describe("SentPacketHandler", func() {
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2}))
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, rcvTime)
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, rcvTime)
Expect(err).NotTo(HaveOccurred())
})
@ -618,7 +598,7 @@ var _ = Describe("SentPacketHandler", func() {
)
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 5}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 5, Largest: 5}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, rcvTime)
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, rcvTime)
Expect(err).ToNot(HaveOccurred())
})
@ -641,7 +621,7 @@ var _ = Describe("SentPacketHandler", func() {
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(3)),
)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
@ -657,11 +637,11 @@ var _ = Describe("SentPacketHandler", func() {
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)),
)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
// don't EXPECT any further calls to the congestion controller
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}}
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
@ -679,7 +659,7 @@ var _ = Describe("SentPacketHandler", func() {
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)),
)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now().Add(-30*time.Minute))
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute))
Expect(err).ToNot(HaveOccurred())
// receive the second ACK
gomock.InOrder(
@ -688,7 +668,7 @@ var _ = Describe("SentPacketHandler", func() {
cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)),
)
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}}
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
})
@ -778,7 +758,7 @@ var _ = Describe("SentPacketHandler", func() {
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 10}))
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 11}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.GetAlarmTimeout()).To(BeZero())
})
@ -906,7 +886,7 @@ var _ = Describe("SentPacketHandler", func() {
// This verifies the RTO.
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 3}}}
err = handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.packetHistory.Len()).To(BeZero())
Expect(handler.bytesInFlight).To(BeZero())
@ -945,7 +925,7 @@ var _ = Describe("SentPacketHandler", func() {
// This verifies the RTO.
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 6}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}}
err = handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err = handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
Expect(handler.packetHistory.Len()).To(BeZero())
Expect(handler.bytesInFlight).To(BeZero())
@ -960,7 +940,7 @@ var _ = Describe("SentPacketHandler", func() {
handler.OnAlarm() // RTO
handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(&Packet{PacketNumber: 6})}, 5)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 5, Largest: 5}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
Expect(err).ToNot(HaveOccurred())
err = handler.OnAlarm()
Expect(err).ToNot(HaveOccurred())
@ -984,7 +964,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.lossTime.IsZero()).To(BeTrue())
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, now)
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now)
Expect(err).NotTo(HaveOccurred())
Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil())
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
@ -1001,7 +981,7 @@ var _ = Describe("SentPacketHandler", func() {
Expect(handler.lossTime.IsZero()).To(BeTrue())
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, now.Add(-time.Second))
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now.Add(-time.Second))
Expect(err).NotTo(HaveOccurred())
Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second))
@ -1033,7 +1013,7 @@ var _ = Describe("SentPacketHandler", func() {
handler.SentPacket(handshakePacket(&Packet{PacketNumber: 3, SendTime: sendTime}))
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, now)
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now)
// RTT is now 1 minute
Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Minute))
Expect(err).NotTo(HaveOccurred())
@ -1055,19 +1035,19 @@ var _ = Describe("SentPacketHandler", func() {
PIt("rejects an ACK that acks packets with a higher encryption level", func() {
handler.SentPacket(&Packet{
PacketNumber: 13,
EncryptionLevel: protocol.EncryptionForwardSecure,
EncryptionLevel: protocol.Encryption1RTT,
Frames: []wire.Frame{&streamFrame},
Length: 1,
})
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}
err := handler.ReceivedAck(ack, 1, protocol.EncryptionSecure, time.Now())
err := handler.ReceivedAck(ack, 1, protocol.EncryptionHandshake, time.Now())
Expect(err).To(MatchError("Received ACK with encryption level encrypted (not forward-secure) that acks a packet 13 (encryption level forward-secure)"))
})
It("deletes non forward-secure packets when the handshake completes", func() {
It("deletes non handshake packets when the handshake completes", func() {
for i := protocol.PacketNumber(1); i <= 6; i++ {
p := retransmittablePacket(&Packet{PacketNumber: i})
p.EncryptionLevel = protocol.EncryptionSecure
p.EncryptionLevel = protocol.EncryptionHandshake
handler.SentPacket(p)
}
handler.queuePacketForRetransmission(getPacket(1))

View file

@ -35,8 +35,7 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
}
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel != protocol.EncryptionForwardSecure &&
p.EncryptionLevel != protocol.Encryption1RTT {
if p.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets++
}
}
@ -107,8 +106,7 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.EncryptionForwardSecure &&
el.Value.EncryptionLevel != protocol.Encryption1RTT {
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
@ -149,8 +147,7 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel != protocol.EncryptionForwardSecure &&
el.Value.EncryptionLevel != protocol.Encryption1RTT {
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")

View file

@ -202,7 +202,7 @@ var _ = Describe("SentPacketHistory", func() {
It("says if it has outstanding handshake packets", func() {
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
EncryptionLevel: protocol.EncryptionInitial,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
@ -212,7 +212,7 @@ var _ = Describe("SentPacketHistory", func() {
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionForwardSecure,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
@ -221,7 +221,7 @@ var _ = Describe("SentPacketHistory", func() {
It("doesn't consider non-retransmittable packets as outstanding", func() {
hist.SentPacket(&Packet{
EncryptionLevel: protocol.EncryptionUnencrypted,
EncryptionLevel: protocol.EncryptionInitial,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
Expect(hist.HasOutstandingPackets()).To(BeFalse())
@ -230,7 +230,7 @@ var _ = Describe("SentPacketHistory", func() {
It("accounts for deleted handshake packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 5,
EncryptionLevel: protocol.EncryptionSecure,
EncryptionLevel: protocol.EncryptionHandshake,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
@ -242,7 +242,7 @@ var _ = Describe("SentPacketHistory", func() {
It("accounts for deleted packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
@ -254,7 +254,7 @@ var _ = Describe("SentPacketHistory", func() {
It("doesn't count handshake packets marked as non-retransmittable", func() {
hist.SentPacket(&Packet{
PacketNumber: 5,
EncryptionLevel: protocol.EncryptionUnencrypted,
EncryptionLevel: protocol.EncryptionInitial,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
@ -266,7 +266,7 @@ var _ = Describe("SentPacketHistory", func() {
It("doesn't count packets marked as non-retransmittable", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
})
Expect(hist.HasOutstandingPackets()).To(BeTrue())
@ -278,12 +278,12 @@ var _ = Describe("SentPacketHistory", func() {
It("counts the number of packets", func() {
hist.SentPacket(&Packet{
PacketNumber: 10,
EncryptionLevel: protocol.EncryptionForwardSecure,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
})
hist.SentPacket(&Packet{
PacketNumber: 11,
EncryptionLevel: protocol.EncryptionForwardSecure,
EncryptionLevel: protocol.Encryption1RTT,
canBeRetransmitted: true,
})
err := hist.Remove(11)

View file

@ -1,43 +0,0 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
type stopWaitingManager struct {
largestLeastUnackedSent protocol.PacketNumber
nextLeastUnacked protocol.PacketNumber
lastStopWaitingFrame *wire.StopWaitingFrame
}
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
if force {
return s.lastStopWaitingFrame
}
return nil
}
s.largestLeastUnackedSent = s.nextLeastUnacked
swf := &wire.StopWaitingFrame{
LeastUnacked: s.nextLeastUnacked,
}
s.lastStopWaitingFrame = swf
return swf
}
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
largestAcked := ack.LargestAcked()
if largestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = largestAcked + 1
}
}
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
if p >= s.nextLeastUnacked {
s.nextLeastUnacked = p + 1
}
}

View file

@ -1,55 +0,0 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("StopWaitingManager", func() {
var manager *stopWaitingManager
BeforeEach(func() {
manager = &stopWaitingManager{}
})
It("returns nil in the beginning", func() {
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
Expect(manager.GetStopWaitingFrame(true)).To(BeNil())
})
It("returns a StopWaitingFrame, when a new ACK arrives", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
It("does not decrease the LeastUnacked", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 9}}})
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
It("does not send the same StopWaitingFrame twice", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
})
It("gets the same StopWaitingFrame twice, if forced", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
})
It("increases the LeastUnacked when a retransmission is queued", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.QueuedRetransmissionForPacketNumber(20)
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 21}))
})
It("does not decrease the LeastUnacked when a retransmission is queued", func() {
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
manager.QueuedRetransmissionForPacketNumber(9)
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
})
})

View file

@ -1,72 +0,0 @@
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/lucas-clemente/aes12"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadAESGCM12 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
var _ AEAD = &aeadAESGCM12{}
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
//
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
// tag size, and couples the cipher and aes packages closely.
// See https://github.com/lucas-clemente/aes12.
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
}
encrypterCipher, err := aes12.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := aes12.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes12.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := aes12.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM12{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}
func (aead *aeadAESGCM12) Overhead() int {
return aead.encrypter.Overhead()
}

View file

@ -1,69 +0,0 @@
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("AES-GCM", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 16)
keyBob = make([]byte, 16)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADAESGCM12(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + bob.Overhead()))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "AES-GCM: expected 16-byte keys and 4-byte IVs"
_, err = NewAEADAESGCM12(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADAESGCM12(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})

View file

@ -1,48 +0,0 @@
package crypto
import (
"fmt"
"hash/fnv"
"github.com/hashicorp/golang-lru"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
compressedCertsCache *lru.Cache
)
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
// Hash all inputs
hasher := fnv.New64a()
for _, v := range chain {
hasher.Write(v)
}
hasher.Write(pCommonSetHashes)
hasher.Write(pCachedHashes)
hash := hasher.Sum64()
var result []byte
resultI, isCached := compressedCertsCache.Get(hash)
if isCached {
result = resultI.([]byte)
} else {
var err error
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
if err != nil {
return nil, err
}
compressedCertsCache.Add(hash, result)
}
return result, nil
}
func init() {
var err error
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
if err != nil {
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
}
}

View file

@ -1,51 +0,0 @@
package crypto
import (
lru "github.com/hashicorp/golang-lru"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Certificate cache", func() {
BeforeEach(func() {
var err error
compressedCertsCache, err = lru.New(2)
Expect(err).NotTo(HaveOccurred())
})
It("gives a compressed cert", func() {
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
expected, err := compressChain(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
compressed, err := getCompressedCert(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal(expected))
})
It("gets the same result multiple times", func() {
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
compressed, err := getCompressedCert(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
compressed2, err := getCompressedCert(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(compressed).To(Equal(compressed2))
})
It("stores cached values", func() {
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
_, err := getCompressedCert(chain, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(compressedCertsCache.Len()).To(Equal(1))
Expect(compressedCertsCache.Contains(uint64(3838929964809501833))).To(BeTrue())
})
It("evicts old values", func() {
_, err := getCompressedCert([][]byte{{0x00}}, nil, nil)
Expect(err).NotTo(HaveOccurred())
_, err = getCompressedCert([][]byte{{0x01}}, nil, nil)
Expect(err).NotTo(HaveOccurred())
_, err = getCompressedCert([][]byte{{0x02}}, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(compressedCertsCache.Len()).To(Equal(2))
})
})

View file

@ -1,113 +0,0 @@
package crypto
import (
"crypto/tls"
"errors"
"strings"
)
// A CertChain holds a certificate and a private key
type CertChain interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
// proofSource stores a key and a certificate for the server proof
type certChain struct {
config *tls.Config
}
var _ CertChain = &certChain{}
var errNoMatchingCertificate = errors.New("no matching certificate found")
// NewCertChain loads the key and cert from files
func NewCertChain(tlsConfig *tls.Config) CertChain {
return &certChain{config: tlsConfig}
}
// SignServerProof signs CHLO and server config for use in the server proof
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return signServerProof(cert, chlo, serverConfigData)
}
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes)
}
// GetLeafCert gets the leaf certificate
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return cert.Certificate[0], nil
}
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
conf := c.config
conf, err := maybeGetConfigForClient(conf, sni)
if err != nil {
return nil, err
}
// The rest of this function is mostly copied from crypto/tls.getCertificate
if conf.GetCertificate != nil {
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil {
return cert, err
}
}
if len(conf.Certificates) == 0 {
return nil, errNoMatchingCertificate
}
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
return &conf.Certificates[0], nil
}
name := strings.ToLower(sni)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
if cert, ok := conf.NameToCertificate[name]; ok {
return cert, nil
}
// try replacing labels in the name with wildcards until we get a
// match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := conf.NameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &conf.Certificates[0], nil
}
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}

View file

@ -1,148 +0,0 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"crypto/tls"
"reflect"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Proof", func() {
var (
cc *certChain
config *tls.Config
cert tls.Certificate
)
BeforeEach(func() {
cert = testdata.GetCertificate()
config = &tls.Config{}
cc = NewCertChain(config).(*certChain)
})
Context("certificate compression", func() {
It("compresses certs", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
kd := &certChain{
config: &tls.Config{
Certificates: []tls.Certificate{
{Certificate: [][]byte{cert}},
},
},
}
certCompressed, err := kd.GetCertsCompressed("", nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(certCompressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("errors when it can't retrieve a certificate", func() {
_, err := cc.GetCertsCompressed("invalid domain", nil, nil)
Expect(err).To(MatchError(errNoMatchingCertificate))
})
})
Context("signing server configs", func() {
It("errors when it can't retrieve a certificate for the requested SNI", func() {
_, err := cc.SignServerProof("invalid", []byte("chlo"), []byte("scfg"))
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("signs the server config", func() {
config.Certificates = []tls.Certificate{cert}
proof, err := cc.SignServerProof("", []byte("chlo"), []byte("scfg"))
Expect(err).ToNot(HaveOccurred())
Expect(proof).ToNot(BeEmpty())
})
})
Context("retrieving certificates", func() {
It("errors without certificates", func() {
_, err := cc.getCertForSNI("")
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("uses first certificate in config.Certificates", func() {
config.Certificates = []tls.Certificate{cert}
cert, err := cc.getCertForSNI("")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries", func() {
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
config.NameToCertificate = map[string]*tls.Certificate{
"quic.clemente.io": &cert,
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses NameToCertificate entries with wildcard", func() {
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
config.NameToCertificate = map[string]*tls.Certificate{
"*.clemente.io": &cert,
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("uses GetCertificate", func() {
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
Expect(clientHello.ServerName).To(Equal("quic.clemente.io"))
return &cert, nil
}
cert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
Expect(cert.PrivateKey).ToNot(BeNil())
Expect(cert.Certificate[0]).ToNot(BeNil())
})
It("gets leaf certificates", func() {
config.Certificates = []tls.Certificate{cert}
cert2, err := cc.GetLeafCert("")
Expect(err).ToNot(HaveOccurred())
Expect(cert2).To(Equal(cert.Certificate[0]))
})
It("errors when it can't retrieve a leaf certificate", func() {
_, err := cc.GetLeafCert("invalid domain")
Expect(err).To(MatchError(errNoMatchingCertificate))
})
It("respects GetConfigForClient", func() {
if !reflect.ValueOf(tls.Config{}).FieldByName("GetConfigForClient").IsValid() {
// Pre 1.8, we don't have to do anything
return
}
nestedConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
l := func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
Expect(chi.ServerName).To(Equal("quic.clemente.io"))
return nestedConfig, nil
}
reflect.ValueOf(config).Elem().FieldByName("GetConfigForClient").Set(reflect.ValueOf(l))
resultCert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).NotTo(HaveOccurred())
Expect(*resultCert).To(Equal(cert))
})
})
})

View file

@ -1,272 +0,0 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"encoding/binary"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type entryType uint8
const (
entryCompressed entryType = 1
entryCached entryType = 2
entryCommon entryType = 3
)
type entry struct {
t entryType
h uint64 // set hash
i uint32 // index
}
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
res := &bytes.Buffer{}
cachedHashes, err := splitHashes(pCachedHashes)
if err != nil {
return nil, err
}
setHashes, err := splitHashes(pCommonSetHashes)
if err != nil {
return nil, err
}
chainHashes := make([]uint64, len(chain))
for i := range chain {
chainHashes[i] = HashCert(chain[i])
}
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
totalUncompressedLen := 0
for i, e := range entries {
res.WriteByte(uint8(e.t))
switch e.t {
case entryCached:
utils.LittleEndian.WriteUint64(res, e.h)
case entryCommon:
utils.LittleEndian.WriteUint64(res, e.h)
utils.LittleEndian.WriteUint32(res, e.i)
case entryCompressed:
totalUncompressedLen += 4 + len(chain[i])
}
}
res.WriteByte(0) // end of list
if totalUncompressedLen > 0 {
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
if err != nil {
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
}
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
for i, e := range entries {
if e.t != entryCompressed {
continue
}
lenCert := len(chain[i])
gz.Write([]byte{
byte(lenCert & 0xff),
byte((lenCert >> 8) & 0xff),
byte((lenCert >> 16) & 0xff),
byte((lenCert >> 24) & 0xff),
})
gz.Write(chain[i])
}
gz.Close()
}
return res.Bytes(), nil
}
func decompressChain(data []byte) ([][]byte, error) {
var chain [][]byte
var entries []entry
r := bytes.NewReader(data)
var numCerts int
var hasCompressedCerts bool
for {
entryTypeByte, err := r.ReadByte()
if entryTypeByte == 0 {
break
}
et := entryType(entryTypeByte)
if err != nil {
return nil, err
}
numCerts++
switch et {
case entryCached:
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
return nil, errors.New("unexpected cached certificate")
case entryCommon:
e := entry{t: entryCommon}
e.h, err = utils.LittleEndian.ReadUint64(r)
if err != nil {
return nil, err
}
e.i, err = utils.LittleEndian.ReadUint32(r)
if err != nil {
return nil, err
}
certSet, ok := certSets[e.h]
if !ok {
return nil, errors.New("unknown certSet")
}
if e.i >= uint32(len(certSet)) {
return nil, errors.New("certificate not found in certSet")
}
entries = append(entries, e)
chain = append(chain, certSet[e.i])
case entryCompressed:
hasCompressedCerts = true
entries = append(entries, entry{t: entryCompressed})
chain = append(chain, nil)
default:
return nil, errors.New("unknown entryType")
}
}
if numCerts == 0 {
return make([][]byte, 0), nil
}
if hasCompressedCerts {
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
if err != nil {
fmt.Println(4)
return nil, err
}
zlibDict := buildZlibDictForEntries(entries, chain)
gz, err := zlib.NewReaderDict(r, zlibDict)
if err != nil {
return nil, err
}
defer gz.Close()
var totalLength uint32
var certIndex int
for totalLength < uncompressedLength {
lenBytes := make([]byte, 4)
_, err := gz.Read(lenBytes)
if err != nil {
return nil, err
}
certLen := binary.LittleEndian.Uint32(lenBytes)
cert := make([]byte, certLen)
n, err := gz.Read(cert)
if uint32(n) != certLen && err != nil {
return nil, err
}
for {
if certIndex >= len(entries) {
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
}
if entries[certIndex].t == entryCompressed {
chain[certIndex] = cert
certIndex++
break
}
certIndex++
}
totalLength += 4 + certLen
}
}
return chain, nil
}
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
res := make([]entry, len(chain))
chainLoop:
for i := range chain {
// Check if hash is in cachedHashes
for j := range cachedHashes {
if chainHashes[i] == cachedHashes[j] {
res[i] = entry{t: entryCached, h: chainHashes[i]}
continue chainLoop
}
}
// Go through common sets and check if it's in there
for _, setHash := range setHashes {
set, ok := certSets[setHash]
if !ok {
// We don't have this set
continue
}
// We have this set, check if chain[i] is in the set
pos := set.findCertInSet(chain[i])
if pos >= 0 {
// Found
res[i] = entry{t: entryCommon, h: setHash, i: uint32(pos)}
continue chainLoop
}
}
res[i] = entry{t: entryCompressed}
}
return res
}
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
var dict bytes.Buffer
// First the cached and common in reverse order
for i := len(entries) - 1; i >= 0; i-- {
if entries[i].t == entryCompressed {
continue
}
dict.Write(chain[i])
}
dict.Write(certDictZlib)
return dict.Bytes()
}
func splitHashes(hashes []byte) ([]uint64, error) {
if len(hashes)%8 != 0 {
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
}
n := len(hashes) / 8
res := make([]uint64, n)
for i := 0; i < n; i++ {
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
}
return res, nil
}
func getCommonCertificateHashes() []byte {
ccs := make([]byte, 8*len(certSets))
i := 0
for certSetHash := range certSets {
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
i++
}
return ccs
}
// HashCert calculates the FNV1a hash of a certificate
func HashCert(cert []byte) uint64 {
h := fnv.New64a()
h.Write(cert)
return h.Sum64()
}

View file

@ -1,294 +0,0 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"encoding/binary"
"errors"
"hash/fnv"
"github.com/lucas-clemente/quic-go-certificates"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
func byteHash(d []byte) []byte {
h := fnv.New64a()
h.Write(d)
s := h.Sum64()
res := make([]byte, 8)
binary.LittleEndian.PutUint64(res, s)
return res
}
var _ = Describe("Cert compression and decompression", func() {
var certSetsOld map[uint64]certSet
BeforeEach(func() {
certSetsOld = make(map[uint64]certSet)
for s := range certSets {
certSetsOld[s] = certSets[s]
}
})
AfterEach(func() {
certSets = certSetsOld
})
It("compresses empty", func() {
compressed, err := compressChain(nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal([]byte{0}))
})
It("decompresses empty", func() {
compressed, err := compressChain(nil, nil, nil)
Expect(err).ToNot(HaveOccurred())
uncompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(uncompressed).To(BeEmpty())
})
It("gives correct single cert", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
chain := [][]byte{cert}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("decompresses a single cert", func() {
cert := []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}
chain := [][]byte{cert}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
uncompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(uncompressed).To(Equal(chain))
})
It("gives correct cert and intermediate", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert1)
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert2)
z.Close()
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(compressed).To(Equal(append([]byte{
0x01, 0x01, 0x00,
0x10, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("decompresses the chain with a cert and an intermediate", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("uses cached certificates", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
certHash := byteHash(cert)
chain := [][]byte{cert}
compressed, err := compressChain(chain, nil, certHash)
Expect(err).ToNot(HaveOccurred())
expected := append([]byte{0x02}, certHash...)
expected = append(expected, 0x00)
Expect(compressed).To(Equal(expected))
})
It("uses cached certificates and compressed combined", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
cert2Hash := byteHash(cert2)
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert1)
z.Close()
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, cert2Hash)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x01, 0x02}
expected = append(expected, cert2Hash...)
expected = append(expected, 0x00)
expected = append(expected, []byte{0x08, 0, 0, 0}...)
expected = append(expected, certZlib.Bytes()...)
Expect(compressed).To(Equal(expected))
})
It("uses common certificate sets", func() {
cert := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x03}
expected = append(expected, setHash...)
expected = append(expected, []byte{42, 0, 0, 0}...)
expected = append(expected, 0x00)
Expect(compressed).To(Equal(expected))
})
It("decompresses a single cert form a common certificate set", func() {
cert := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("decompresses multiple certs form common certificate sets", func() {
cert1 := certsets.CertSet3[42]
cert2 := certsets.CertSet2[24]
setHash := make([]byte, 16)
binary.LittleEndian.PutUint64(setHash[0:8], certsets.CertSet3Hash)
binary.LittleEndian.PutUint64(setHash[8:16], certsets.CertSet2Hash)
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("ignores uncommon certificate sets", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, 0xdeadbeef)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert)
z.Close()
Expect(compressed).To(Equal(append([]byte{
0x01, 0x00,
0x08, 0x00, 0x00, 0x00,
}, certZlib.Bytes()...)))
})
It("errors if a common set does not exist", func() {
cert := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
delete(certSets, certsets.CertSet3Hash)
_, err = decompressChain(compressed)
Expect(err).To(MatchError(errors.New("unknown certSet")))
})
It("errors if a cert in a common set does not exist", func() {
certSet := [][]byte{
{0x1, 0x2, 0x3, 0x4},
{0x5, 0x6, 0x7, 0x8},
}
certSets[0x1337] = certSet
cert := certSet[1]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, 0x1337)
chain := [][]byte{cert}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
certSets[0x1337] = certSet[:1] // delete the last certificate from the certSet
_, err = decompressChain(compressed)
Expect(err).To(MatchError(errors.New("certificate not found in certSet")))
})
It("uses common certificates and compressed combined", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
certZlib := &bytes.Buffer{}
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
Expect(err).ToNot(HaveOccurred())
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
z.Write(cert1)
z.Close()
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x01, 0x03}
expected = append(expected, setHash...)
expected = append(expected, []byte{42, 0, 0, 0}...)
expected = append(expected, 0x00)
expected = append(expected, []byte{0x08, 0, 0, 0}...)
expected = append(expected, certZlib.Bytes()...)
Expect(compressed).To(Equal(expected))
})
It("decompresses a certficate from a common set and a compressed cert combined", func() {
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
cert2 := certsets.CertSet3[42]
setHash := make([]byte, 8)
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, setHash, nil)
Expect(err).ToNot(HaveOccurred())
decompressed, err := decompressChain(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(decompressed).To(Equal(chain))
})
It("rejects invalid CCS / CCRT hashes", func() {
cert := []byte{0xde, 0xca, 0xfb, 0xad}
chain := [][]byte{cert}
_, err := compressChain(chain, []byte("foo"), nil)
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
_, err = compressChain(chain, nil, []byte("foo"))
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
})
Context("common certificate hashes", func() {
It("gets the hashes", func() {
ccs := getCommonCertificateHashes()
Expect(ccs).ToNot(BeEmpty())
hashes, err := splitHashes(ccs)
Expect(err).ToNot(HaveOccurred())
for _, hash := range hashes {
Expect(certSets).To(HaveKey(hash))
}
})
It("returns an empty slice if there are not common sets", func() {
certSets = make(map[uint64]certSet)
ccs := getCommonCertificateHashes()
Expect(ccs).ToNot(BeNil())
Expect(ccs).To(HaveLen(0))
})
})
})

View file

@ -1,128 +0,0 @@
package crypto
var certDictZlib = []byte{
0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04,
0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03,
0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30,
0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01,
0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07,
0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65,
0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34,
0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31,
0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73,
0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65,
0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73,
0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06,
0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86,
0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01,
0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2,
0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e,
0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69,
0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03,
0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09,
0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30,
0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08,
0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30,
0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74,
0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03,
0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33,
0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63,
0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74,
0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03,
0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53,
0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68,
0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55,
0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37,
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d,
0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c,
0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00,
0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff,
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05,
0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff,
0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d,
0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86,
0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e,
0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08,
0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74,
0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65,
0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17,
0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04,
0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86,
0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17,
0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72,
0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31,
0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65,
0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74,
0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73,
0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68,
0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76,
0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e,
0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11,
0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01,
0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04,
0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13,
0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50,
0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e,
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30,
0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61,
0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56,
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31,
0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65,
0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63,
0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53,
0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41,
0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e,
0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72,
0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a,
0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72,
0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72,
0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63,
0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72,
0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05,
0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b,
0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64,
0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06,
0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68,
0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64,
0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73,
0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74,
0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72,
0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16,
0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee,
0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27,
0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30,
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73,
}

View file

@ -1,135 +0,0 @@
package crypto
import (
"crypto/tls"
"crypto/x509"
"errors"
"hash/fnv"
"time"
"github.com/lucas-clemente/quic-go/qerr"
)
// CertManager manages the certificates sent by the server
type CertManager interface {
SetData([]byte) error
GetCommonCertificateHashes() []byte
GetLeafCert() []byte
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
GetChain() []*x509.Certificate
}
type certManager struct {
chain []*x509.Certificate
config *tls.Config
}
var _ CertManager = &certManager{}
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
// NewCertManager creates a new CertManager
func NewCertManager(tlsConfig *tls.Config) CertManager {
return &certManager{config: tlsConfig}
}
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
func (c *certManager) SetData(data []byte) error {
byteChain, err := decompressChain(data)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
chain := make([]*x509.Certificate, len(byteChain))
for i, data := range byteChain {
cert, err := x509.ParseCertificate(data)
if err != nil {
return err
}
chain[i] = cert
}
c.chain = chain
return nil
}
func (c *certManager) GetChain() []*x509.Certificate {
return c.chain
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}
// GetLeafCert returns the leaf certificate of the certificate chain
// it returns nil if the certificate chain has not yet been set
func (c *certManager) GetLeafCert() []byte {
if len(c.chain) == 0 {
return nil
}
return c.chain[0].Raw
}
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
func (c *certManager) GetLeafCertHash() (uint64, error) {
leafCert := c.GetLeafCert()
if leafCert == nil {
return 0, errNoCertificateChain
}
h := fnv.New64a()
_, err := h.Write(leafCert)
if err != nil {
return 0, err
}
return h.Sum64(), nil
}
// VerifyServerProof verifies the signature of the server config
// it should only be called after the certificate chain has been set, otherwise it returns false
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
if len(c.chain) == 0 {
return false
}
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
}
// Verify verifies the certificate chain
func (c *certManager) Verify(hostname string) error {
if len(c.chain) == 0 {
return errNoCertificateChain
}
if c.config != nil && c.config.InsecureSkipVerify {
return nil
}
leafCert := c.chain[0]
var opts x509.VerifyOptions
if c.config != nil {
opts.Roots = c.config.RootCAs
if c.config.Time == nil {
opts.CurrentTime = time.Now()
} else {
opts.CurrentTime = c.config.Time()
}
}
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 {
intermediates := x509.NewCertPool()
for i := 1; i < len(c.chain); i++ {
intermediates.AddCert(c.chain[i])
}
opts.Intermediates = intermediates
}
_, err := leafCert.Verify(opts)
return err
}

View file

@ -1,348 +0,0 @@
package crypto
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"math/big"
"runtime"
"time"
"github.com/lucas-clemente/quic-go/internal/testdata"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Cert Manager", func() {
var cm *certManager
var key1, key2 *rsa.PrivateKey
var cert1, cert2 []byte
BeforeEach(func() {
var err error
cm = NewCertManager(nil).(*certManager)
key1, err = rsa.GenerateKey(rand.Reader, 768)
Expect(err).ToNot(HaveOccurred())
key2, err = rsa.GenerateKey(rand.Reader, 768)
Expect(err).ToNot(HaveOccurred())
template := &x509.Certificate{SerialNumber: big.NewInt(1)}
cert1, err = x509.CreateCertificate(rand.Reader, template, template, &key1.PublicKey, key1)
Expect(err).ToNot(HaveOccurred())
cert2, err = x509.CreateCertificate(rand.Reader, template, template, &key2.PublicKey, key2)
Expect(err).ToNot(HaveOccurred())
})
It("saves a client TLS config", func() {
tlsConf := &tls.Config{ServerName: "quic.clemente.io"}
cm = NewCertManager(tlsConf).(*certManager)
Expect(cm.config.ServerName).To(Equal("quic.clemente.io"))
})
It("errors when given invalid data", func() {
err := cm.SetData([]byte("foobar"))
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
})
It("gets the common certificate hashes", func() {
ccs := cm.GetCommonCertificateHashes()
Expect(ccs).ToNot(BeEmpty())
})
Context("setting the data", func() {
It("decompresses a certificate chain", func() {
chain := [][]byte{cert1, cert2}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = cm.SetData(compressed)
Expect(err).ToNot(HaveOccurred())
Expect(cm.chain[0].Raw).To(Equal(cert1))
Expect(cm.chain[1].Raw).To(Equal(cert2))
})
It("errors if it can't decompress the chain", func() {
err := cm.SetData([]byte("invalid data"))
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
})
It("errors if it can't parse a certificate", func() {
chain := [][]byte{[]byte("cert1"), []byte("cert2")}
compressed, err := compressChain(chain, nil, nil)
Expect(err).ToNot(HaveOccurred())
err = cm.SetData(compressed)
_, ok := err.(asn1.StructuralError)
Expect(ok).To(BeTrue())
})
})
Context("getting the leaf cert", func() {
It("gets it", func() {
xcert1, err := x509.ParseCertificate(cert1)
Expect(err).ToNot(HaveOccurred())
xcert2, err := x509.ParseCertificate(cert2)
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert1, xcert2}
leafCert := cm.GetLeafCert()
Expect(leafCert).To(Equal(cert1))
})
It("returns nil if the chain hasn't been set yet", func() {
leafCert := cm.GetLeafCert()
Expect(leafCert).To(BeNil())
})
})
Context("getting the leaf cert hash", func() {
It("calculates the FVN1a 64 hash", func() {
cm.chain = make([]*x509.Certificate, 1)
cm.chain[0] = &x509.Certificate{
Raw: []byte("test fnv hash"),
}
hash, err := cm.GetLeafCertHash()
Expect(err).ToNot(HaveOccurred())
// hash calculated on http://www.nitrxgen.net/hashgen/
Expect(hash).To(Equal(uint64(0x4770f6141fa0f5ad)))
})
It("errors if the certificate chain is not loaded", func() {
_, err := cm.GetLeafCertHash()
Expect(err).To(MatchError(errNoCertificateChain))
})
})
Context("verifying the server config signature", func() {
It("returns false when the chain hasn't been set yet", func() {
valid := cm.VerifyServerProof([]byte("proof"), []byte("chlo"), []byte("scfg"))
Expect(valid).To(BeFalse())
})
It("verifies the signature", func() {
chlo := []byte("client hello")
scfg := []byte("server config data")
xcert1, err := x509.ParseCertificate(cert1)
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert1}
proof, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
valid := cm.VerifyServerProof(proof, chlo, scfg)
Expect(valid).To(BeTrue())
})
It("rejects an invalid signature", func() {
xcert1, err := x509.ParseCertificate(cert1)
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert1}
valid := cm.VerifyServerProof([]byte("invalid proof"), []byte("chlo"), []byte("scfg"))
Expect(valid).To(BeFalse())
})
})
Context("verifying the certificate chain", func() {
generateCertificate := func(template, parent *x509.Certificate, pubKey *rsa.PublicKey, privKey *rsa.PrivateKey) *x509.Certificate {
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pubKey, privKey)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return cert
}
getCertificate := func(template *x509.Certificate) (*rsa.PrivateKey, *x509.Certificate) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
return key, generateCertificate(template, template, &key.PublicKey, key)
}
It("accepts a valid certificate", func() {
cc := NewCertChain(testdata.GetTLSConfig()).(*certChain)
tlsCert, err := cc.getCertForSNI("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
for _, data := range tlsCert.Certificate {
var cert *x509.Certificate
cert, err = x509.ParseCertificate(data)
Expect(err).ToNot(HaveOccurred())
cm.chain = append(cm.chain, cert)
}
err = cm.Verify("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
})
It("doesn't accept an expired certificate", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-25 * time.Hour),
NotAfter: time.Now().Add(-time.Hour),
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("")
Expect(err).To(HaveOccurred())
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
})
It("doesn't accept a certificate that is not yet valid", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(time.Hour),
NotAfter: time.Now().Add(25 * time.Hour),
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("")
Expect(err).To(HaveOccurred())
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
})
It("doesn't accept an certificate for the wrong hostname", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
Subject: pkix.Name{CommonName: "google.com"},
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("quic.clemente.io")
Expect(err).To(HaveOccurred())
_, ok := err.(x509.HostnameError)
Expect(ok).To(BeTrue())
})
It("errors if the chain hasn't been set yet", func() {
err := cm.Verify("example.com")
Expect(err).To(HaveOccurred())
})
// this tests relies on LetsEncrypt not being contained in the Root CAs
It("rejects valid certificate with missing certificate chain", func() {
if runtime.GOOS == "windows" {
Skip("LetsEncrypt Root CA is included in Windows")
}
cert := testdata.GetCertificate()
xcert, err := x509.ParseCertificate(cert.Certificate[0])
Expect(err).ToNot(HaveOccurred())
cm.chain = []*x509.Certificate{xcert}
err = cm.Verify("quic.clemente.io")
_, ok := err.(x509.UnknownAuthorityError)
Expect(ok).To(BeTrue())
})
It("doesn't do any certificate verification if InsecureSkipVerify is set", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
}
_, leafCert := getCertificate(template)
cm.config = &tls.Config{
InsecureSkipVerify: true,
}
cm.chain = []*x509.Certificate{leafCert}
err := cm.Verify("quic.clemente.io")
Expect(err).ToNot(HaveOccurred())
})
It("uses the time specified in a client TLS config", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-25 * time.Hour),
NotAfter: time.Now().Add(-23 * time.Hour),
Subject: pkix.Name{CommonName: "quic.clemente.io"},
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
cm.config = &tls.Config{
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
}
err := cm.Verify("quic.clemente.io")
_, ok := err.(x509.UnknownAuthorityError)
Expect(ok).To(BeTrue())
})
It("rejects certificates that are expired at the time specified in a client TLS config", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
}
_, leafCert := getCertificate(template)
cm.chain = []*x509.Certificate{leafCert}
cm.config = &tls.Config{
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
}
err := cm.Verify("quic.clemente.io")
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
})
It("uses the Root CA given in the client config", func() {
if runtime.GOOS == "windows" {
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
Skip("windows")
}
templateRoot := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
BasicConstraintsValid: true,
}
rootKey, rootCert := getCertificate(templateRoot)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
Subject: pkix.Name{CommonName: "google.com"},
}
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).ToNot(HaveOccurred())
leafCert := generateCertificate(template, rootCert, &key.PublicKey, rootKey)
rootCAPool := x509.NewCertPool()
rootCAPool.AddCert(rootCert)
cm.chain = []*x509.Certificate{leafCert}
cm.config = &tls.Config{
RootCAs: rootCAPool,
}
err = cm.Verify("google.com")
Expect(err).ToNot(HaveOccurred())
})
})
})

View file

@ -1,24 +0,0 @@
package crypto
import (
"bytes"
"github.com/lucas-clemente/quic-go-certificates"
)
type certSet [][]byte
var certSets = map[uint64]certSet{
certsets.CertSet2Hash: certsets.CertSet2,
certsets.CertSet3Hash: certsets.CertSet3,
}
// findCertInSet searches for the cert in the set. Negative return value means not found.
func (s *certSet) findCertInSet(cert []byte) int {
for i, c := range *s {
if bytes.Equal(c, cert) {
return i
}
}
return -1
}

View file

@ -1,61 +0,0 @@
// +build ignore
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/aead/chacha20"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadChacha20Poly1305 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
}
// copy because ChaCha20Poly1305 expects array pointers
var MyKey, OtherKey [32]byte
copy(MyKey[:], myKey)
copy(OtherKey[:], otherKey)
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
if err != nil {
return nil, err
}
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
if err != nil {
return nil, err
}
return &aeadChacha20Poly1305{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View file

@ -1,71 +0,0 @@
// +build ignore
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Chacha20poly1305", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 32)
keyBob = make([]byte, 32)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + 12))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})

View file

@ -1,41 +0,0 @@
package crypto
import (
"crypto/rand"
"errors"
"golang.org/x/crypto/curve25519"
)
// KeyExchange manages the exchange of keys
type curve25519KEX struct {
secret [32]byte
public [32]byte
}
var _ KeyExchange = &curve25519KEX{}
// NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html
func NewCurve25519KEX() (KeyExchange, error) {
c := &curve25519KEX{}
if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key")
}
curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil
}
func (c *curve25519KEX) PublicKey() []byte {
return c.public[:]
}
func (c *curve25519KEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
if len(otherPublic) != 32 {
return nil, errors.New("Curve25519: expected public key of 32 byte")
}
var res [32]byte
var otherPublicArray [32]byte
copy(otherPublicArray[:], otherPublic)
curve25519.ScalarMult(&res, &c.secret, &otherPublicArray)
return res[:], nil
}

View file

@ -1,27 +0,0 @@
package crypto
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("ProofRsa", func() {
It("works", func() {
a, err := NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
b, err := NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
sA, err := a.CalculateSharedKey(b.PublicKey())
Expect(err).ToNot(HaveOccurred())
sB, err := b.CalculateSharedKey(a.PublicKey())
Expect(err).ToNot(HaveOccurred())
Expect(sA).To(Equal(sB))
})
It("rejects short public keys", func() {
a, err := NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
_, err = a.CalculateSharedKey(nil)
Expect(err).To(MatchError("Curve25519: expected public key of 32 byte"))
})
})

View file

@ -1,100 +0,0 @@
package crypto
import (
"bytes"
"crypto/sha256"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"golang.org/x/crypto/hkdf"
)
// DeriveKeysChacha20 derives the client and server keys and creates a matching chacha20poly1305 AEAD instance
// func DeriveKeysChacha20(version protocol.VersionNumber, forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) {
// otherKey, myKey, otherIV, myIV, err := deriveKeys(version, forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 32)
// if err != nil {
// return nil, err
// }
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
// }
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
var swap bool
if pers == protocol.PerspectiveClient {
swap = true
}
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
if err != nil {
return nil, err
}
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
}
// deriveKeys derives the keys and the IVs
// swap should be set true if generating the values for the client, and false for the server
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
var info bytes.Buffer
if forwardSecure {
info.Write([]byte("QUIC forward secure key expansion\x00"))
} else {
info.Write([]byte("QUIC key expansion\x00"))
}
info.Write(connID)
info.Write(chlo)
info.Write(scfg)
info.Write(cert)
r := hkdf.New(sha256.New, sharedSecret, nonces, info.Bytes())
s := make([]byte, 2*keyLen+2*4)
if _, err := io.ReadFull(r, s); err != nil {
return nil, nil, nil, nil, err
}
key1 := s[:keyLen]
key2 := s[keyLen : 2*keyLen]
iv1 := s[2*keyLen : 2*keyLen+4]
iv2 := s[2*keyLen+4:]
var otherKey, myKey []byte
var otherIV, myIV []byte
if !forwardSecure {
if err := diversify(key2, iv2, divNonce); err != nil {
return nil, nil, nil, nil, err
}
}
if swap {
otherKey = key2
myKey = key1
otherIV = iv2
myIV = iv1
} else {
otherKey = key1
myKey = key2
otherIV = iv1
myIV = iv2
}
return otherKey, myKey, otherIV, myIV, nil
}
func diversify(key, iv, divNonce []byte) error {
secret := make([]byte, len(key)+len(iv))
copy(secret, key)
copy(secret[len(key):], iv)
r := hkdf.New(sha256.New, secret, divNonce, []byte("QUIC key diversification"))
if _, err := io.ReadFull(r, key); err != nil {
return err
}
if _, err := io.ReadFull(r, iv); err != nil {
return err
}
return nil
}

View file

@ -1,197 +0,0 @@
package crypto
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("QUIC Crypto Key Derivation", func() {
// Context("chacha20poly1305", func() {
// It("derives non-fs keys", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version32,
// false,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// nil,
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xf0, 0xf5, 0x4c, 0xa8}))
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
// })
//
// It("derives fs keys", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version32,
// true,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// nil,
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
// })
//
// It("does not use diversification nonces in FS key derivation", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version33,
// true,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// []byte("divnonce"),
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
// })
//
// It("uses diversification nonces in initial key derivation", func() {
// aead, err := DeriveKeysChacha20(
// protocol.Version33,
// false,
// []byte("0123456789012345678901"),
// []byte("nonce"),
// protocol.ConnectionID(42),
// []byte("chlo"),
// []byte("scfg"),
// []byte("cert"),
// []byte("divnonce"),
// )
// Expect(err).ToNot(HaveOccurred())
// chacha := aead.(*aeadChacha20Poly1305)
// // If the IVs match, the keys will match too, since the keys are read earlier
// Expect(chacha.myIV).To(Equal([]byte{0xc4, 0x12, 0x25, 0x64}))
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
// })
// })
Context("AES-GCM", func() {
It("derives non-forward secure keys", func() {
aead, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.myIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
Expect(aesgcm.otherIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
})
It("uses the diversification nonce when generating non-forwared secure keys", func() {
aead1, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aead2, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("ecnonvid"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm1 := aead1.(*aeadAESGCM12)
aesgcm2 := aead2.(*aeadAESGCM12)
Expect(aesgcm1.myIV).ToNot(Equal(aesgcm2.myIV))
Expect(aesgcm1.otherIV).To(Equal(aesgcm2.otherIV))
})
It("derives non-forward secure keys, for the other side", func() {
aead, err := DeriveQuicCryptoAESKeys(
false,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveClient,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.otherIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
Expect(aesgcm.myIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
})
It("derives forward secure keys", func() {
aead, err := DeriveQuicCryptoAESKeys(
true,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
nil,
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
})
It("does not use div-nonce for FS key derivation", func() {
aead, err := DeriveQuicCryptoAESKeys(
true,
[]byte("0123456789012345678901"),
[]byte("nonce"),
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
[]byte("chlo"),
[]byte("scfg"),
[]byte("cert"),
[]byte("divnonce"),
protocol.PerspectiveServer,
)
Expect(err).ToNot(HaveOccurred())
aesgcm := aead.(*aeadAESGCM12)
// If the IVs match, the keys will match too, since the keys are read earlier
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
})
})
})

View file

@ -1,7 +0,0 @@
package crypto
// KeyExchange manages the exchange of keys
type KeyExchange interface {
PublicKey() []byte
CalculateSharedKey(otherPublic []byte) ([]byte, error)
}

View file

@ -1,11 +0,0 @@
package crypto
import "github.com/lucas-clemente/quic-go/internal/protocol"
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
if v.UsesTLS() {
return newNullAEADAESGCM(connID, p)
}
return &nullAEADFNV128a{perspective: p}, nil
}

View file

@ -8,7 +8,8 @@ import (
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID)
var mySecret, otherSecret []byte

View file

@ -56,9 +56,9 @@ var _ = Describe("NullAEAD using AES-GCM", func() {
It("seals and opens", func() {
connectionID := protocol.ConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef})
clientAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveClient)
clientAEAD, err := NewNullAEAD(connectionID, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveServer)
serverAEAD, err := NewNullAEAD(connectionID, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))
@ -74,9 +74,9 @@ var _ = Describe("NullAEAD using AES-GCM", func() {
It("doesn't work if initialized with different connection IDs", func() {
c1 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
c2 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2})
clientAEAD, err := newNullAEADAESGCM(c1, protocol.PerspectiveClient)
clientAEAD, err := NewNullAEAD(c1, protocol.PerspectiveClient)
Expect(err).ToNot(HaveOccurred())
serverAEAD, err := newNullAEADAESGCM(c2, protocol.PerspectiveServer)
serverAEAD, err := NewNullAEAD(c2, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))

View file

@ -1,79 +0,0 @@
package crypto
import (
"bytes"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// nullAEAD handles not-yet encrypted packets
type nullAEADFNV128a struct {
perspective protocol.Perspective
}
var _ AEAD = &nullAEADFNV128a{}
// Open and verify the ciphertext
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
if len(src) < 12 {
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
}
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src[12:])
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Client"))
} else {
hash.Write([]byte("Server"))
}
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
if !bytes.Equal(sum[:12], src[:12]) {
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
}
return src[12:], nil
}
// Seal writes hash and ciphertext to the buffer
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
if cap(dst) < 12+len(src) {
dst = make([]byte, 12+len(src))
} else {
dst = dst[:12+len(src)]
}
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src)
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Server"))
} else {
hash.Write([]byte("Client"))
}
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
copy(dst[12:], src)
copy(dst, sum[:12])
return dst
}
func (n *nullAEADFNV128a) Overhead() int {
return 12
}
func reverse(a []byte) {
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
a[left], a[right] = a[right], a[left]
}
}

View file

@ -1,55 +0,0 @@
package crypto
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("NullAEAD using FNV128a", func() {
aad := []byte("All human beings are born free and equal in dignity and rights.")
plainText := []byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")
hash36 := []byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7}
var aeadServer AEAD
var aeadClient AEAD
BeforeEach(func() {
aeadServer = &nullAEADFNV128a{protocol.PerspectiveServer}
aeadClient = &nullAEADFNV128a{protocol.PerspectiveClient}
})
It("seals and opens, client => server", func() {
cipherText := aeadClient.Seal(nil, plainText, 0, aad)
res, err := aeadServer.Open(nil, cipherText, 0, aad)
Expect(err).ToNot(HaveOccurred())
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
})
It("seals and opens, server => client", func() {
cipherText := aeadServer.Seal(nil, plainText, 0, aad)
res, err := aeadClient.Open(nil, cipherText, 0, aad)
Expect(err).ToNot(HaveOccurred())
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
})
It("rejects short ciphertexts", func() {
_, err := aeadServer.Open(nil, nil, 0, nil)
Expect(err).To(MatchError("NullAEAD: ciphertext cannot be less than 12 bytes long"))
})
It("seals in-place", func() {
buf := make([]byte, 6, 12+6)
copy(buf, []byte("foobar"))
res := aeadServer.Seal(buf[0:0], buf, 0, nil)
buf = buf[:12+6]
Expect(buf[12:]).To(Equal([]byte("foobar")))
Expect(res[12:]).To(Equal([]byte("foobar")))
})
It("fails", func() {
cipherText := append(append(hash36, plainText...), byte(0x42))
_, err := aeadClient.Open(nil, cipherText, 0, aad)
Expect(err).To(HaveOccurred())
})
})

View file

@ -1,17 +0,0 @@
package crypto
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("NullAEAD", func() {
It("selects the right FVN variant", func() {
connID := protocol.ConnectionID([]byte{0x42, 0, 0, 0, 0, 0, 0, 0})
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.Version39)).To(Equal(&nullAEADFNV128a{
perspective: protocol.PerspectiveClient,
}))
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.VersionTLS)).To(BeAssignableToTypeOf(&aeadAESGCM{}))
})
})

View file

@ -1,66 +0,0 @@
package crypto
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"math/big"
)
type ecdsaSignature struct {
R, S *big.Int
}
// signServerProof signs CHLO and server config for use in the server proof
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
key, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
}
opts := crypto.SignerOpts(crypto.SHA256)
if _, ok = key.(*rsa.PrivateKey); ok {
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
}
return key.Sign(rand.Reader, hash.Sum(nil), opts)
}
// verifyServerProof verifies the server proof signature
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
// RSA
if cert.PublicKeyAlgorithm == x509.RSA {
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
return err == nil
}
// ECDSA
signature := &ecdsaSignature{}
rest, err := asn1.Unmarshal(proof, signature)
if err != nil || len(rest) != 0 {
return false
}
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
}

View file

@ -1,127 +0,0 @@
package crypto
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"math/big"
"github.com/lucas-clemente/quic-go/internal/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Proof", func() {
It("gives valid signatures with the key in internal/testdata", func() {
key := &testdata.GetTLSConfig().Certificates[0]
signature, err := signServerProof(key, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
Expect(err).ToNot(HaveOccurred())
// Generated with:
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
err = rsa.VerifyPSS(key.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey), crypto.SHA256, data, signature, &rsa.PSSOptions{SaltLength: 32})
Expect(err).ToNot(HaveOccurred())
})
Context("when using RSA", func() {
generateCert := func() (*rsa.PrivateKey, *x509.Certificate) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
Expect(err).NotTo(HaveOccurred())
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return key, cert
}
It("verifies a signature", func() {
key, cert := generateCert()
chlo := []byte("chlo")
scfg := []byte("scfg")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
})
It("rejects invalid signatures", func() {
key, cert := generateCert()
chlo := []byte("client hello")
scfg := []byte("sever config")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
})
})
Context("when using ECDSA", func() {
generateCert := func() (*ecdsa.PrivateKey, *x509.Certificate) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Expect(err).NotTo(HaveOccurred())
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
Expect(err).ToNot(HaveOccurred())
cert, err := x509.ParseCertificate(certDER)
Expect(err).ToNot(HaveOccurred())
return key, cert
}
It("gives valid signatures", func() {
key, _ := generateCert()
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
Expect(err).ToNot(HaveOccurred())
// Generated with:
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
s := &ecdsaSignature{}
_, err = asn1.Unmarshal(signature, s)
Expect(err).NotTo(HaveOccurred())
b := ecdsa.Verify(key.Public().(*ecdsa.PublicKey), data, s.R, s.S)
Expect(b).To(BeTrue())
})
It("verifies a signature", func() {
key, cert := generateCert()
chlo := []byte("chlo")
scfg := []byte("server config")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
})
It("rejects invalid signatures", func() {
key, cert := generateCert()
chlo := []byte("client hello")
scfg := []byte("server config")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
})
It("rejects signatures generated with a different certificate", func() {
key1, cert1 := generateCert()
key2, cert2 := generateCert()
Expect(key1.PublicKey).ToNot(Equal(key2))
Expect(cert1.Equal(cert2)).To(BeFalse())
chlo := []byte("chlo")
scfg := []byte("sfcg")
signature, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
Expect(err).ToNot(HaveOccurred())
Expect(verifyServerProof(signature, cert2, chlo, scfg)).To(BeFalse())
})
})
})

View file

@ -16,8 +16,7 @@ type streamFlowController struct {
queueWindowUpdate func()
connection connectionFlowControllerI
contributesToConnection bool // does the stream contribute to connection level flow control
connection connectionFlowControllerI
receivedFinalOffset bool
}
@ -27,7 +26,6 @@ var _ StreamFlowController = &streamFlowController{}
// NewStreamFlowController gets a new flow controller for a stream
func NewStreamFlowController(
streamID protocol.StreamID,
contributesToConnection bool,
cfc ConnectionFlowController,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
@ -37,10 +35,9 @@ func NewStreamFlowController(
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
contributesToConnection: contributesToConnection,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
@ -87,32 +84,21 @@ func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCou
if c.checkFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
}
if c.contributesToConnection {
return c.connection.IncrementHighestReceived(increment)
}
return nil
return c.connection.IncrementHighestReceived(increment)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
c.baseFlowController.AddBytesRead(n)
if c.contributesToConnection {
c.connection.AddBytesRead(n)
}
c.connection.AddBytesRead(n)
}
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
c.baseFlowController.AddBytesSent(n)
if c.contributesToConnection {
c.connection.AddBytesSent(n)
}
c.connection.AddBytesSent(n)
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
window := c.baseFlowController.sendWindowSize()
if c.contributesToConnection {
window = utils.MinByteCount(window, c.connection.SendWindowSize())
}
return window
return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
}
func (c *streamFlowController) MaybeQueueWindowUpdate() {
@ -122,9 +108,7 @@ func (c *streamFlowController) MaybeQueueWindowUpdate() {
if hasWindowUpdate {
c.queueWindowUpdate()
}
if c.contributesToConnection {
c.connection.MaybeQueueWindowUpdate()
}
c.connection.MaybeQueueWindowUpdate()
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
@ -140,9 +124,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
c.mutex.Unlock()
return offset

View file

@ -40,12 +40,11 @@ var _ = Describe("Stream Flow controller", func() {
It("sets the send and receive windows", func() {
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
Expect(fc.receiveWindow).To(Equal(receiveWindow))
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
Expect(fc.sendWindow).To(Equal(sendWindow))
Expect(fc.contributesToConnection).To(BeTrue())
})
It("queues window updates with the correction stream ID", func() {
@ -56,7 +55,7 @@ var _ = Describe("Stream Flow controller", func() {
}
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
fc.AddBytesRead(receiveWindow)
fc.MaybeQueueWindowUpdate()
Expect(queued).To(BeTrue())
@ -82,21 +81,12 @@ var _ = Describe("Stream Flow controller", func() {
It("informs the connection flow controller about received data", func() {
controller.highestReceived = 10
controller.contributesToConnection = true
controller.connection.(*connectionFlowController).highestReceived = 100
err := controller.UpdateHighestReceived(20, false)
Expect(err).ToNot(HaveOccurred())
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100 + 10)))
})
It("doesn't informs the connection flow controller about received data if it doesn't contribute", func() {
controller.highestReceived = 10
controller.connection.(*connectionFlowController).highestReceived = 100
err := controller.UpdateHighestReceived(20, false)
Expect(err).ToNot(HaveOccurred())
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100)))
})
It("does not decrease the highestReceived", func() {
controller.highestReceived = 1337
err := controller.UpdateHighestReceived(1000, false)
@ -111,6 +101,7 @@ var _ = Describe("Stream Flow controller", func() {
})
It("does not give a flow control violation when using the window completely", func() {
controller.connection.(*connectionFlowController).receiveWindow = receiveWindow
err := controller.UpdateHighestReceived(receiveWindow, false)
Expect(err).ToNot(HaveOccurred())
})
@ -163,19 +154,10 @@ var _ = Describe("Stream Flow controller", func() {
})
})
Context("registering data read", func() {
It("saves when data is read, on a stream not contributing to the connection", func() {
controller.AddBytesRead(100)
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(100)))
Expect(controller.connection.(*connectionFlowController).bytesRead).To(BeZero())
})
It("saves when data is read, on a stream not contributing to the connection", func() {
controller.contributesToConnection = true
controller.AddBytesRead(200)
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200)))
Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200)))
})
It("saves when data is read", func() {
controller.AddBytesRead(200)
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200)))
Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200)))
})
Context("generating window updates", func() {
@ -209,7 +191,6 @@ var _ = Describe("Stream Flow controller", func() {
})
It("queues connection-level window updates", func() {
controller.contributesToConnection = true
controller.MaybeQueueWindowUpdate()
Expect(queuedConnWindowUpdate).To(BeFalse())
controller.AddBytesRead(60)
@ -219,7 +200,6 @@ var _ = Describe("Stream Flow controller", func() {
It("tells the connection flow controller when the window was autotuned", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = true
setRtt(scaleDuration(20 * time.Millisecond))
controller.epochStartOffset = oldOffset
controller.epochStartTime = time.Now().Add(-time.Millisecond)
@ -230,19 +210,6 @@ var _ = Describe("Stream Flow controller", func() {
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)))
})
It("doesn't tell the connection flow controller if it doesn't contribute", func() {
oldOffset := controller.bytesRead
controller.contributesToConnection = false
setRtt(scaleDuration(20 * time.Millisecond))
controller.epochStartOffset = oldOffset
controller.epochStartTime = time.Now().Add(-time.Millisecond)
controller.AddBytesRead(55)
offset := controller.GetWindowUpdate()
Expect(offset).ToNot(BeZero())
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged
})
It("doesn't increase the window after a final offset was already received", func() {
controller.AddBytesRead(30)
err := controller.UpdateHighestReceived(90, true)
@ -257,20 +224,13 @@ var _ = Describe("Stream Flow controller", func() {
Context("sending data", func() {
It("gets the size of the send window", func() {
controller.connection.UpdateSendWindow(1000)
controller.UpdateSendWindow(15)
controller.AddBytesSent(5)
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
})
It("doesn't care about the connection-level window, if it doesn't contribute", func() {
controller.UpdateSendWindow(15)
controller.connection.UpdateSendWindow(1)
controller.AddBytesSent(5)
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
})
It("makes sure that it doesn't overflow the connection-level window", func() {
controller.contributesToConnection = true
controller.connection.UpdateSendWindow(12)
controller.UpdateSendWindow(20)
controller.AddBytesSent(10)
@ -278,7 +238,6 @@ var _ = Describe("Stream Flow controller", func() {
})
It("doesn't say that it's blocked, if only the connection is blocked", func() {
controller.contributesToConnection = true
controller.connection.UpdateSendWindow(50)
controller.UpdateSendWindow(100)
controller.AddBytesSent(50)

View file

@ -46,7 +46,7 @@ func (m messageType) String() string {
}
}
type cryptoSetupTLS struct {
type cryptoSetup struct {
tlsConf *qtls.Config
messageChan chan []byte
@ -92,8 +92,8 @@ type cryptoSetupTLS struct {
perspective protocol.Perspective
}
var _ qtls.RecordLayer = &cryptoSetupTLS{}
var _ CryptoSetupTLS = &cryptoSetupTLS{}
var _ qtls.RecordLayer = &cryptoSetup{}
var _ CryptoSetup = &cryptoSetup{}
type versionInfo struct {
initialVersion protocol.VersionNumber
@ -101,8 +101,8 @@ type versionInfo struct {
currentVersion protocol.VersionNumber
}
// NewCryptoSetupTLSClient creates a new TLS crypto setup for the client
func NewCryptoSetupTLSClient(
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
@ -114,8 +114,8 @@ func NewCryptoSetupTLSClient(
currentVersion protocol.VersionNumber,
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetupTLS, <-chan struct{} /* ClientHello written */, error) {
return newCryptoSetupTLS(
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
return newCryptoSetup(
initialStream,
handshakeStream,
connID,
@ -132,8 +132,8 @@ func NewCryptoSetupTLSClient(
)
}
// NewCryptoSetupTLSServer creates a new TLS crypto setup for the server
func NewCryptoSetupTLSServer(
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
@ -144,8 +144,8 @@ func NewCryptoSetupTLSServer(
currentVersion protocol.VersionNumber,
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetupTLS, error) {
cs, _, err := newCryptoSetupTLS(
) (CryptoSetup, error) {
cs, _, err := newCryptoSetup(
initialStream,
handshakeStream,
connID,
@ -162,7 +162,7 @@ func NewCryptoSetupTLSServer(
return cs, err
}
func newCryptoSetupTLS(
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
@ -172,12 +172,12 @@ func newCryptoSetupTLS(
versionInfo versionInfo,
logger utils.Logger,
perspective protocol.Perspective,
) (CryptoSetupTLS, <-chan struct{} /* ClientHello written */, error) {
initialAEAD, err := crypto.NewNullAEAD(perspective, connID, protocol.VersionTLS)
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
initialAEAD, err := crypto.NewNullAEAD(connID, perspective)
if err != nil {
return nil, nil, err
}
cs := &cryptoSetupTLS{
cs := &cryptoSetup{
initialStream: initialStream,
initialAEAD: initialAEAD,
handshakeStream: handshakeStream,
@ -221,7 +221,7 @@ func newCryptoSetupTLS(
return cs, cs.clientHelloWrittenChan, nil
}
func (h *cryptoSetupTLS) RunHandshake() error {
func (h *cryptoSetup) RunHandshake() error {
var conn *qtls.Conn
switch h.perspective {
case protocol.PerspectiveClient:
@ -264,7 +264,7 @@ func (h *cryptoSetupTLS) RunHandshake() error {
}
}
func (h *cryptoSetupTLS) Close() error {
func (h *cryptoSetup) Close() error {
close(h.closeChan)
// wait until qtls.Handshake() actually returned
<-h.handshakeDone
@ -274,7 +274,7 @@ func (h *cryptoSetupTLS) Close() error {
// handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
@ -292,7 +292,7 @@ func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.Encryption
}
}
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
var expected protocol.EncryptionLevel
switch msgType {
case typeClientHello,
@ -313,7 +313,7 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
return nil
}
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
switch msgType {
case typeClientHello:
select {
@ -358,7 +358,7 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
}
}
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
switch msgType {
case typeServerHello:
// get the handshake read key
@ -408,7 +408,7 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
// TODO: add some error handling here (when the session is closed)
msg, ok := <-h.messageChan
if !ok {
@ -417,7 +417,7 @@ func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
return msg, nil
}
func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
opener := newOpener(suite.AEAD(key, iv), iv)
@ -437,7 +437,7 @@ func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byt
h.receivedReadKey <- struct{}{}
}
func (h *cryptoSetupTLS) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
sealer := newSealer(suite.AEAD(key, iv), iv)
@ -458,7 +458,7 @@ func (h *cryptoSetupTLS) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []by
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetupTLS) WriteRecord(p []byte) (int, error) {
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
switch h.writeEncLevel {
case protocol.EncryptionInitial:
// assume that the first WriteRecord call contains the ClientHello
@ -475,7 +475,7 @@ func (h *cryptoSetupTLS) WriteRecord(p []byte) (int, error) {
}
}
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
if h.sealer != nil {
return protocol.Encryption1RTT, h.sealer
}
@ -485,7 +485,7 @@ func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionInitial, h.initialAEAD
}
func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
switch level {
@ -506,25 +506,25 @@ func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(level protocol.EncryptionL
}
}
func (h *cryptoSetupTLS) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
return h.initialAEAD.Open(dst, src, pn, ad)
}
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.handshakeOpener == nil {
return nil, errors.New("no handshake opener")
}
return h.handshakeOpener.Open(dst, src, pn, ad)
}
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
if h.opener == nil {
return nil, errors.New("no 1-RTT opener")
}
return h.opener.Open(dst, src, pn, ad)
}
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
func (h *cryptoSetup) ConnectionState() ConnectionState {
// TODO: return the connection state
return ConnectionState{}
}

View file

@ -1,545 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type cryptoSetupClient struct {
mutex sync.RWMutex
hostname string
connID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
negotiatedVersions []protocol.VersionNumber
cryptoStream io.ReadWriter
serverConfig *serverConfigClient
stk []byte
sno []byte
nonc []byte
proof []byte
chloForSignature []byte
lastSentCHLO []byte
certManager crypto.CertManager
divNonceChan chan struct{}
diversificationNonce []byte
clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation QuicCryptoKeyDerivationFunction
receivedSecurePacket bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
params *TransportParameters
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupClient{}
var (
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available")
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated")
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces")
)
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
func NewCryptoSetupClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
version protocol.VersionNumber,
tlsConf *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: tlsConf.ServerName,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConf),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
initialVersion: initialVersion,
// The server might have sent greased versions in the Version Negotiation packet.
// We need strip those from the list, since they won't be included in the handshake tag.
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
divNonceChan: divNonceChan,
logger: logger,
}
return cs, nil
}
func (h *cryptoSetupClient) RunHandshake() error {
messageChan := make(chan HandshakeMessage)
errorChan := make(chan error, 1)
go func() {
for {
message, err := ParseHandshakeMessage(h.cryptoStream)
if err != nil {
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
return
}
messageChan <- message
}
}()
for {
if err := h.maybeUpgradeCrypto(); err != nil {
return err
}
h.mutex.RLock()
sendCHLO := h.secureAEAD == nil
h.mutex.RUnlock()
if sendCHLO {
if err := h.sendCHLO(); err != nil {
return err
}
}
var message HandshakeMessage
select {
case <-h.divNonceChan:
// there's no message to process, but we should try upgrading the crypto again
continue
case message = <-messageChan:
case err := <-errorChan:
return err
}
h.logger.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
if err := h.handleREJMessage(message.Data); err != nil {
return err
}
case TagSHLO:
params, err := h.handleSHLOMessage(message.Data)
if err != nil {
return err
}
// blocks until the session has received the parameters
h.paramsChan <- *params
h.handshakeEvent <- struct{}{}
close(h.handshakeComplete)
default:
return qerr.InvalidCryptoMessageType
}
}
}
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
var err error
if stk, ok := cryptoData[TagSTK]; ok {
h.stk = stk
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
// TODO: what happens if the server sends a different server config in two packets?
if scfg, ok := cryptoData[TagSCFG]; ok {
h.serverConfig, err = parseServerConfig(scfg)
if err != nil {
return err
}
if h.serverConfig.IsExpired() {
return qerr.CryptoServerConfigExpired
}
// now that we have a server config, we can use its OBIT value to generate a client nonce
if len(h.nonc) == 0 {
err = h.generateClientNonce()
if err != nil {
return err
}
}
}
if proof, ok := cryptoData[TagPROF]; ok {
h.proof = proof
h.chloForSignature = h.lastSentCHLO
}
if crt, ok := cryptoData[TagCERT]; ok {
err := h.certManager.SetData(crt)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
err = h.certManager.Verify(h.hostname)
if err != nil {
h.logger.Infof("Certificate validation failed: %s", err.Error())
return qerr.ProofInvalid
}
}
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
if !validProof {
h.logger.Infof("Server proof verification failed")
return qerr.ProofInvalid
}
h.serverVerified = true
}
return nil
}
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.receivedSecurePacket {
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
serverPubs, ok := cryptoData[TagPUBS]
if !ok {
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
verTag, ok := cryptoData[TagVER]
if !ok {
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
}
if !h.validateVersionList(verTag) {
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
nonce := append(h.nonc, h.sno...)
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
if err != nil {
return nil, err
}
leafCert := h.certManager.GetLeafCert()
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
nil,
protocol.PerspectiveClient,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
params, err := readHelloMap(cryptoData)
if err != nil {
return nil, qerr.InvalidCryptoMessageParameter
}
return params, nil
}
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
numNegotiatedVersions := len(h.negotiatedVersions)
if numNegotiatedVersions == 0 {
return true
}
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
return false
}
b := bytes.NewReader(verTags)
for i := 0; i < numNegotiatedVersions; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil { // should never occur, since the length was already checked
return false
}
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
return false
}
}
return true
}
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
return data, protocol.EncryptionForwardSecure, nil
}
return nil, protocol.EncryptionUnspecified, err
}
if h.secureAEAD != nil {
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return data, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, nil
}
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
} else if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
} else {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
}
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupClient: no encryption level specified")
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
}
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
h.mutex.Lock()
if len(h.diversificationNonce) > 0 {
defer h.mutex.Unlock()
if !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
return nil
}
h.diversificationNonce = divNonce
h.mutex.Unlock()
h.divNonceChan <- struct{}{}
return nil
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos))
}
b := &bytes.Buffer{}
tags, err := h.getTags()
if err != nil {
return err
}
h.addPadding(tags)
message := HandshakeMessage{
Tag: TagCHLO,
Data: tags,
}
h.logger.Debugf("Sending %s", message)
message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes())
if err != nil {
return err
}
h.lastSentCHLO = b.Bytes()
return nil
}
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
tags := h.params.getHelloMap()
tags[TagSNI] = []byte(h.hostname)
tags[TagPDMD] = []byte("X509")
ccs := h.certManager.GetCommonCertificateHashes()
if len(ccs) > 0 {
tags[TagCCS] = ccs
}
versionTag := make([]byte, 4)
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
tags[TagVER] = versionTag
if len(h.stk) > 0 {
tags[TagSTK] = h.stk
}
if len(h.sno) > 0 {
tags[TagSNO] = h.sno
}
if h.serverConfig != nil {
tags[TagSCID] = h.serverConfig.ID
leafCert := h.certManager.GetLeafCert()
if leafCert != nil {
certHash, _ := h.certManager.GetLeafCertHash()
xlct := make([]byte, 8)
binary.LittleEndian.PutUint64(xlct, certHash)
tags[TagNONC] = h.nonc
tags[TagXLCT] = xlct
tags[TagKEXS] = []byte("C255")
tags[TagAEAD] = []byte("AESG")
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
}
}
return tags, nil
}
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
var size int
for _, tag := range tags {
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
}
paddingSize := protocol.MinClientHelloSize - size
if paddingSize > 0 {
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
}
}
func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if !h.serverVerified {
return nil
}
h.mutex.Lock()
defer h.mutex.Unlock()
leafCert := h.certManager.GetLeafCert()
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
var err error
var nonce []byte
if h.sno == nil {
nonce = h.nonc
} else {
nonce = append(h.nonc, h.sno...)
}
h.secureAEAD, err = h.keyDerivation(
false,
h.serverConfig.sharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
h.diversificationNonce,
protocol.PerspectiveClient,
)
if err != nil {
return err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
}
return nil
}
func (h *cryptoSetupClient) generateClientNonce() error {
if len(h.nonc) > 0 {
return errClientNonceAlreadyExists
}
nonc := make([]byte, 32)
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix()))
if len(h.serverConfig.obit) != 8 {
return errNoObitForClientNonce
}
copy(nonc[4:12], h.serverConfig.obit)
_, err := rand.Read(nonc[12:])
if err != nil {
return err
}
h.nonc = nonc
return nil
}

File diff suppressed because it is too large Load diff

View file

@ -1,470 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// QuicCryptoKeyDerivationFunction is used for key derivation
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct {
mutex sync.RWMutex
connID protocol.ConnectionID
remoteAddr net.Addr
scfg *ServerConfig
diversificationNonce []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
acceptSTKCallback func(net.Addr, *Cookie) bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
receivedForwardSecurePacket bool
receivedSecurePacket bool
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
receivedParams bool
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
handshakeComplete chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
cryptoStream io.ReadWriter
params *TransportParameters
sni string // need to fill out the ConnectionState
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupServer{}
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
// NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
remoteAddr net.Addr,
version protocol.VersionNumber,
divNonce []byte,
scfg *ServerConfig,
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
handshakeComplete chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupServer{
cryptoStream: cryptoStream,
connID: connID,
remoteAddr: remoteAddr,
version: version,
supportedVersions: supportedVersions,
diversificationNonce: divNonce,
scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
params: params,
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
handshakeComplete: handshakeComplete,
logger: logger,
}, nil
}
// HandleCryptoStream reads and writes messages on the crypto stream
func (h *cryptoSetupServer) RunHandshake() error {
for {
var chloData bytes.Buffer
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
if err != nil {
return qerr.HandshakeFailed
}
if message.Tag != TagCHLO {
return qerr.InvalidCryptoMessageType
}
h.logger.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data)
if err != nil {
return err
}
if done {
return nil
}
}
}
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
return false, ErrNSTPExperiment
}
sniSlice, ok := cryptoData[TagSNI]
if !ok {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
sni := string(sniSlice)
if sni == "" {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
h.sni = sni
// prevent version downgrade attacks
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
verSlice, ok := cryptoData[TagVER]
if !ok {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")
}
if len(verSlice) != 4 {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")
}
ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice))
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
var reply []byte
var err error
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return false, err
}
params, err := readHelloMap(cryptoData)
if err != nil {
return false, err
}
// blocks until the session has received the parameters
if !h.receivedParams {
h.receivedParams = true
h.paramsChan <- *params
}
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
reply, err = h.handleCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err
}
h.handshakeEvent <- struct{}{}
close(h.sentSHLO)
return true, nil
}
// We have an inchoate or non-matching CHLO, we now send a rejection
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
_, err = h.cryptoStream.Write(reply)
return false, err
}
// Open a message
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
h.receivedForwardSecurePacket = true
// wait for the send on the handshakeEvent chan
<-h.sentSHLO
close(h.handshakeComplete)
}
return res, protocol.EncryptionForwardSecure, nil
}
if h.receivedForwardSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
if h.secureAEAD != nil {
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return res, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return res, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, err
}
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupServer: no encryption level specified")
}
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
if _, ok := cryptoData[TagPUBS]; !ok {
return true
}
scid, ok := cryptoData[TagSCID]
if !ok || !bytes.Equal(h.scfg.ID, scid) {
return true
}
xlctTag, ok := cryptoData[TagXLCT]
if !ok || len(xlctTag) != 8 {
return true
}
xlct := binary.LittleEndian.Uint64(xlctTag)
if crypto.HashCert(cert) != xlct {
return true
}
return !h.acceptSTK(cryptoData[TagSTK])
}
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
if err != nil {
h.logger.Debugf("STK invalid: %s", err.Error())
return false
}
return h.acceptSTKCallback(h.remoteAddr, stk)
}
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
if err != nil {
return nil, err
}
replyMap := map[Tag][]byte{
TagSCFG: h.scfg.Get(),
TagSTK: token,
TagSVID: []byte("quic-go"),
}
if h.acceptSTK(cryptoData[TagSTK]) {
proof, err := h.scfg.Sign(sni, chlo)
if err != nil {
return nil, err
}
commonSetHashes := cryptoData[TagCCS]
cachedCertsHashes := cryptoData[TagCCRT]
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
if err != nil {
return nil, err
}
// Token was valid, send more details
replyMap[TagPROF] = proof
replyMap[TagCERT] = certCompressed
}
message := HandshakeMessage{
Tag: TagREJ,
Data: replyMap,
}
var serverReply bytes.Buffer
message.Write(&serverReply)
h.logger.Debugf("Sending %s", message)
return serverReply.Bytes(), nil
}
func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
// We have a CHLO matching our server config, we can continue with the 0-RTT handshake
sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.mutex.Lock()
defer h.mutex.Unlock()
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return nil, err
}
serverNonce := make([]byte, 32)
if _, err = rand.Read(serverNonce); err != nil {
return nil, err
}
clientNonce := cryptoData[TagNONC]
err = h.validateClientNonce(clientNonce)
if err != nil {
return nil, err
}
aead := cryptoData[TagAEAD]
if !bytes.Equal(aead, []byte("AESG")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
kexs := cryptoData[TagKEXS]
if !bytes.Equal(kexs, []byte("C255")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
h.secureAEAD, err = h.keyDerivation(
false,
sharedSecret,
clientNonce,
h.connID,
data,
h.scfg.Get(),
certUncompressed,
h.diversificationNonce,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
// Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer
fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce)
ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
fsNonce.Bytes(),
h.connID,
data,
h.scfg.Get(),
certUncompressed,
nil,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
replyMap := h.params.getHelloMap()
// add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range h.supportedVersions {
utils.BigEndian.WriteUint32(verTag, uint32(v))
}
replyMap[TagPUBS] = ephermalKex.PublicKey()
replyMap[TagSNO] = serverNonce
replyMap[TagVER] = verTag.Bytes()
// note that the SHLO *has* to fit into one packet
message := HandshakeMessage{
Tag: TagSHLO,
Data: replyMap,
}
var reply bytes.Buffer
message.Write(&reply)
h.logger.Debugf("Sending %s", message)
return reply.Bytes(), nil
}
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
ServerName: h.sni,
HandshakeComplete: h.receivedForwardSecurePacket,
}
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
}
if !bytes.Equal(nonce[4:12], h.scfg.obit) {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")
}
return nil
}

View file

@ -1,734 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
type mockKEX struct {
ephermal bool
sharedKeyError error
}
func (m *mockKEX) PublicKey() []byte {
if m.ephermal {
return []byte("ephermal pub")
}
return []byte("initial public")
}
func (m *mockKEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
if m.sharedKeyError != nil {
return nil, m.sharedKeyError
}
if m.ephermal {
return []byte("shared ephermal"), nil
}
return []byte("shared key"), nil
}
type mockSigner struct {
gotCHLO bool
}
func (s *mockSigner) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
if len(chlo) > 0 {
s.gotCHLO = true
}
return []byte("proof"), nil
}
func (*mockSigner) GetCertsCompressed(sni string, common, cached []byte) ([]byte, error) {
return []byte("certcompressed"), nil
}
func (*mockSigner) GetLeafCert(sni string) ([]byte, error) {
return []byte("certuncompressed"), nil
}
func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
type mockStream struct {
unblockRead chan struct{}
dataToRead bytes.Buffer
dataWritten bytes.Buffer
}
var _ io.ReadWriter = &mockStream{}
var errMockStreamClosing = errors.New("mock stream closing")
func newMockStream() *mockStream {
return &mockStream{unblockRead: make(chan struct{})}
}
// call Close to make Read return
func (s *mockStream) Read(p []byte) (int, error) {
n, _ := s.dataToRead.Read(p)
if n == 0 { // block if there's no data
<-s.unblockRead
return 0, errMockStreamClosing
}
return n, nil // never return an EOF
}
func (s *mockStream) Write(p []byte) (int, error) {
return s.dataWritten.Write(p)
}
func (s *mockStream) close() {
close(s.unblockRead)
}
type mockCookieProtector struct {
decodeErr error
}
var _ cookieProtector = &mockCookieProtector{}
func (mockCookieProtector) NewToken(sourceAddr []byte) ([]byte, error) {
return append([]byte("token "), sourceAddr...), nil
}
func (s mockCookieProtector) DecodeToken(data []byte) ([]byte, error) {
if s.decodeErr != nil {
return nil, s.decodeErr
}
if len(data) < 6 {
return nil, errors.New("token too short")
}
return data[6:], nil
}
var _ = Describe("Server Crypto Setup", func() {
var (
kex *mockKEX
signer *mockSigner
scfg *ServerConfig
cs *cryptoSetupServer
stream *mockStream
paramsChan chan TransportParameters
handshakeEvent chan struct{}
handshakeComplete chan struct{}
nonce32 []byte
versionTag []byte
validSTK []byte
aead []byte
kexs []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
sourceAddrValid bool
)
const (
expectedInitialNonceLen = 32
expectedFSNonceLen = 64
)
BeforeEach(func() {
var err error
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
paramsChan = make(chan TransportParameters, 1)
handshakeEvent = make(chan struct{}, 2)
handshakeComplete = make(chan struct{})
stream = newMockStream()
kex = &mockKEX{}
signer = &mockSigner{}
scfg, err = NewServerConfig(kex, signer)
nonce32 = make([]byte, 32)
aead = []byte("AESG")
kexs = []byte("C255")
copy(nonce32[4:12], scfg.obit) // set the OBIT value at the right position
versionTag = make([]byte, 4)
binary.BigEndian.PutUint32(versionTag, uint32(protocol.VersionWhatever))
Expect(err).NotTo(HaveOccurred())
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99}
csInt, err := NewCryptoSetup(
stream,
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
remoteAddr,
version,
make([]byte, 32), // div nonce
scfg,
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
supportedVersions,
nil,
paramsChan,
handshakeEvent,
handshakeComplete,
utils.DefaultLogger,
)
Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer)
cs.scfg.cookieGenerator.cookieProtector = &mockCookieProtector{}
validSTK, err = cs.scfg.cookieGenerator.NewToken(remoteAddr)
Expect(err).NotTo(HaveOccurred())
sourceAddrValid = true
cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid }
cs.keyDerivation = mockQuicCryptoKeyDerivation
cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil }
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
cs.cryptoStream = stream
})
Context("when responding to client messages", func() {
var cert []byte
var xlct []byte
var fullCHLO map[Tag][]byte
BeforeEach(func() {
xlct = make([]byte, 8)
var err error
cert, err = cs.scfg.certChain.GetLeafCert("")
Expect(err).ToNot(HaveOccurred())
binary.LittleEndian.PutUint64(xlct, crypto.HashCert(cert))
fullCHLO = map[Tag][]byte{
TagSCID: scfg.ID,
TagSNI: []byte("quic.clemente.io"),
TagNONC: nonce32,
TagSTK: validSTK,
TagXLCT: xlct,
TagAEAD: aead,
TagKEXS: kexs,
TagPUBS: bytes.Repeat([]byte{'e'}, 31),
TagVER: versionTag,
}
})
It("doesn't support Chrome's no STOP_WAITING experiment", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagNSTP: []byte("foobar"),
},
}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(ErrNSTPExperiment))
})
It("reads the transport parameters sent by the client", func() {
sourceAddrValid = true
fullCHLO[TagICSL] = []byte{0x37, 0x13, 0, 0}
_, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), fullCHLO)
Expect(err).ToNot(HaveOccurred())
var params TransportParameters
Expect(paramsChan).To(Receive(&params))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
})
It("generates REJ messages", func() {
sourceAddrValid = false
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ"))
Expect(response).To(ContainSubstring("initial public"))
Expect(response).ToNot(ContainSubstring("certcompressed"))
Expect(response).ToNot(ContainSubstring("proof"))
Expect(signer.gotCHLO).To(BeFalse())
})
It("REJ messages don't include cert or proof without STK", func() {
sourceAddrValid = false
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ"))
Expect(response).ToNot(ContainSubstring("certcompressed"))
Expect(response).ToNot(ContainSubstring("proof"))
Expect(signer.gotCHLO).To(BeFalse())
})
It("REJ messages include cert and proof with valid STK", func() {
sourceAddrValid = true
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), map[Tag][]byte{
TagSTK: validSTK,
TagSNI: []byte("foo"),
})
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("REJ"))
Expect(response).To(ContainSubstring("certcompressed"))
Expect(response).To(ContainSubstring("proof"))
Expect(signer.gotCHLO).To(BeTrue())
})
It("generates SHLO messages", func() {
var checkedSecure, checkedForwardSecure bool
cs.keyDerivation = func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
if forwardSecure {
Expect(nonces).To(HaveLen(expectedFSNonceLen))
checkedForwardSecure = true
Expect(sharedSecret).To(Equal([]byte("shared ephermal")))
} else {
Expect(nonces).To(HaveLen(expectedInitialNonceLen))
Expect(sharedSecret).To(Equal([]byte("shared key")))
checkedSecure = true
}
return mockcrypto.NewMockAEAD(mockCtrl), nil
}
response, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
TagPUBS: []byte("pubs-c"),
TagNONC: nonce32,
TagAEAD: aead,
TagKEXS: kexs,
})
Expect(err).ToNot(HaveOccurred())
Expect(response).To(HavePrefix("SHLO"))
message, err := ParseHandshakeMessage(bytes.NewReader(response))
Expect(err).ToNot(HaveOccurred())
Expect(message.Data).To(HaveKeyWithValue(TagPUBS, []byte("ephermal pub")))
Expect(message.Data).To(HaveKey(TagSNO))
Expect(message.Data).To(HaveKey(TagVER))
Expect(message.Data[TagVER]).To(HaveLen(4 * len(supportedVersions)))
for _, v := range supportedVersions {
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(v))
Expect(message.Data[TagVER]).To(ContainSubstring(b.String()))
}
Expect(checkedSecure).To(BeTrue())
Expect(checkedForwardSecure).To(BeTrue())
})
It("handles long handshake", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSNI: []byte("quic.clemente.io"),
TagSTK: validSTK,
TagPAD: bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
TagVER: versionTag,
},
}.Write(&stream.dataToRead)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeComplete).ToNot(BeClosed())
})
It("rejects client nonces that have the wrong length", func() {
fullCHLO[TagNONC] = []byte("too short client nonce")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")))
})
It("rejects client nonces that have the wrong OBIT value", func() {
fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")))
})
It("errors if it can't calculate a shared key", func() {
testErr := errors.New("test error")
kex.sharedKeyError = testErr
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(testErr))
})
It("handles 0-RTT handshake", func() {
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).NotTo(HaveOccurred())
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
Expect(handshakeEvent).To(Receive()) // for the switch to secure
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
Expect(handshakeComplete).ToNot(BeClosed())
})
It("recognizes inchoate CHLOs missing SCID", func() {
delete(fullCHLO, TagSCID)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs missing PUBS", func() {
delete(fullCHLO, TagPUBS)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with missing XLCT", func() {
delete(fullCHLO, TagXLCT)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with wrong length XLCT", func() {
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 7) // should be 8 bytes
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with wrong XLCT", func() {
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 8)
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes inchoate CHLOs with an invalid STK", func() {
testErr := errors.New("STK invalid")
cs.scfg.cookieGenerator.cookieProtector.(*mockCookieProtector).decodeErr = testErr
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
})
It("recognizes proper CHLOs", func() {
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
})
It("rejects CHLOs without the version tag", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSCID: scfg.ID,
TagSNI: []byte("quic.clemente.io"),
},
}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")))
})
It("rejects CHLOs with a version tag that has the wrong length", func() {
fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")))
})
It("detects version downgrade attacks", func() {
highestSupportedVersion := supportedVersions[len(supportedVersions)-1]
lowestSupportedVersion := supportedVersions[0]
Expect(highestSupportedVersion).ToNot(Equal(lowestSupportedVersion))
cs.version = highestSupportedVersion
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(lowestSupportedVersion))
fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")))
})
It("accepts a non-matching version tag in the CHLO, if it is an unsupported version", func() {
supportedVersion := protocol.SupportedVersions[0]
unsupportedVersion := supportedVersion + 1000
Expect(protocol.IsSupportedVersion(supportedVersions, unsupportedVersion)).To(BeFalse())
cs.version = supportedVersion
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(unsupportedVersion))
fullCHLO[TagVER] = b
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).ToNot(HaveOccurred())
})
It("errors if the AEAD tag is missing", func() {
delete(fullCHLO, TagAEAD)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
It("errors if the AEAD tag has the wrong value", func() {
fullCHLO[TagAEAD] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
It("errors if the KEXS tag is missing", func() {
delete(fullCHLO, TagKEXS)
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
It("errors if the KEXS tag has the wrong value", func() {
fullCHLO[TagKEXS] = []byte("wrong")
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
})
})
It("errors without SNI", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSTK: validSTK,
},
}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
})
It("errors with empty SNI", func() {
HandshakeMessage{
Tag: TagCHLO,
Data: map[Tag][]byte{
TagSTK: validSTK,
TagSNI: nil,
},
}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
})
It("errors with invalid message", func() {
stream.dataToRead.Write([]byte("invalid message"))
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.HandshakeFailed))
})
It("errors with non-CHLO message", func() {
HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead)
err := cs.RunHandshake()
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
})
Context("escalating crypto", func() {
doCHLO := func() {
_, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
TagPUBS: []byte("pubs-c"),
TagNONC: nonce32,
TagAEAD: aead,
TagKEXS: kexs,
})
Expect(err).ToNot(HaveOccurred())
Expect(handshakeEvent).To(Receive()) // for the switch to secure
close(cs.sentSHLO)
}
Context("null encryption", func() {
It("is used initially", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar signed"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
d := sealer.Seal(nil, []byte("foobar"), 10, []byte{})
Expect(d).To(Equal([]byte("foobar signed")))
})
It("is used for the crypto stream", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
})
It("is accepted initially", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(5), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("unencrypted"), 5, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
})
It("errors if the has the wrong hash", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("not unencrypted"), protocol.PacketNumber(5), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err := cs.Open(nil, []byte("not unencrypted"), 5, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is still accepted after CHLO", func() {
doCHLO()
// it tries forward secure and secure decryption first
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{})
Expect(cs.secureAEAD).ToNot(BeNil())
_, enc, err := cs.Open(nil, []byte("unencrypted"), 99, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
})
It("is not accepted after receiving secure packet", func() {
doCHLO()
// first receive a secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
Expect(enc).To(Equal(protocol.EncryptionSecure))
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
// now receive an unencrypted packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err = cs.Open(nil, []byte("unencrypted"), 99, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is not used after CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
enc, sealer := cs.GetSealer()
Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted))
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
})
})
Context("initial encryption", func() {
It("is accepted after CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
Expect(enc).To(Equal(protocol.EncryptionSecure))
Expect(err).ToNot(HaveOccurred())
Expect(d).To(Equal([]byte("decrypted")))
})
It("is not accepted after receiving forward secure packet", func() {
doCHLO()
// receive a forward secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
Expect(err).ToNot(HaveOccurred())
// receive a secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(12), []byte{}).Return(nil, errors.New("authentication failed"))
_, enc, err := cs.Open(nil, []byte("encrypted"), 12, []byte{})
Expect(err).To(MatchError("authentication failed"))
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
})
It("is used for the crypto stream", func() {
doCHLO()
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(1), []byte{}).Return([]byte("foobar crypto stream"))
enc, sealer := cs.GetSealerForCryptoStream()
Expect(enc).To(Equal(protocol.EncryptionSecure))
d := sealer.Seal(nil, []byte("foobar"), 1, []byte{})
Expect(d).To(Equal([]byte("foobar crypto stream")))
})
})
Context("forward secure encryption", func() {
It("is used after the CHLO", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar forward sec"))
enc, sealer := cs.GetSealer()
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
d := sealer.Seal(nil, []byte("foobar"), 20, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})
It("regards the handshake as complete once it receives a forward encrypted packet", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
Expect(err).ToNot(HaveOccurred())
Expect(handshakeComplete).To(BeClosed())
})
})
Context("reporting the connection state", func() {
It("reports before the handshake completes", func() {
cs.sni = "server name"
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeFalse())
Expect(state.ServerName).To(Equal("server name"))
})
It("reports after the handshake completes", func() {
doCHLO()
// receive a forward secure packet
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
Expect(err).ToNot(HaveOccurred())
state := cs.ConnectionState()
Expect(state.HandshakeComplete).To(BeTrue())
})
})
Context("forcing encryption levels", func() {
It("forces null encryption", func() {
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(11), []byte{}).Return([]byte("foobar unencrypted"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 11, []byte{})
Expect(d).To(Equal([]byte("foobar unencrypted")))
})
It("forces initial encryption", func() {
doCHLO()
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(12), []byte{}).Return([]byte("foobar secure"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 12, []byte{})
Expect(d).To(Equal([]byte("foobar secure")))
})
It("errors if no AEAD for initial encryption is available", func() {
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
Expect(err).To(MatchError("CryptoSetupServer: no secureAEAD"))
Expect(sealer).To(BeNil())
})
It("forces forward-secure encryption", func() {
doCHLO()
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(13), []byte{}).Return([]byte("foobar forward sec"))
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).ToNot(HaveOccurred())
d := sealer.Seal(nil, []byte("foobar"), 13, []byte{})
Expect(d).To(Equal([]byte("foobar forward sec")))
})
It("errors of no AEAD for forward-secure encryption is available", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
Expect(err).To(MatchError("CryptoSetupServer: no forwardSecureAEAD"))
Expect(seal).To(BeNil())
})
It("errors if no encryption level is specified", func() {
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified)
Expect(err).To(MatchError("CryptoSetupServer: no encryption level specified"))
Expect(seal).To(BeNil())
})
})
})
Context("STK verification and creation", func() {
It("requires STK", func() {
sourceAddrValid = false
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
map[Tag][]byte{
TagSNI: []byte("foo"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
Expect(cs.sni).To(Equal("foo"))
})
It("works with proper STK", func() {
sourceAddrValid = true
done, err := cs.handleMessage(
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
map[Tag][]byte{
TagSNI: []byte("foo"),
TagVER: versionTag,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(done).To(BeFalse())
})
})
})

View file

@ -57,7 +57,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when an error occurs", func() {
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
@ -87,7 +87,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when handling a message fails", func() {
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
@ -116,7 +116,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("returns Handshake() when it is closed", func() {
_, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
@ -162,9 +162,9 @@ var _ = Describe("Crypto Setup TLS", func() {
}
handshake := func(
client CryptoSetupTLS,
client CryptoSetup,
cChunkChan <-chan chunk,
server CryptoSetupTLS,
server CryptoSetup,
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
done := make(chan struct{})
defer close(done)
@ -195,7 +195,7 @@ var _ = Describe("Crypto Setup TLS", func() {
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, _, err := NewCryptoSetupTLSClient(
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
@ -211,7 +211,7 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(err).ToNot(HaveOccurred())
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server, err := NewCryptoSetupTLSServer(
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},
@ -250,7 +250,7 @@ var _ = Describe("Crypto Setup TLS", func() {
It("signals when it has written the ClientHello", func() {
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
client, chChan, err := NewCryptoSetupTLSClient(
client, chChan, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
@ -289,7 +289,7 @@ var _ = Describe("Crypto Setup TLS", func() {
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
client, _, err := NewCryptoSetupTLSClient(
client, _, err := NewCryptoSetupClient(
cInitialStream,
cHandshakeStream,
protocol.ConnectionID{},
@ -309,7 +309,7 @@ var _ = Describe("Crypto Setup TLS", func() {
IdleTimeout: 0x1337 * time.Second,
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
}
server, err := NewCryptoSetupTLSServer(
server, err := NewCryptoSetupServer(
sInitialStream,
sHandshakeStream,
protocol.ConnectionID{},

File diff suppressed because one or more lines are too long

View file

@ -1,48 +0,0 @@
package handshake
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
kexLifetime = protocol.EphermalKeyLifetime
kexCurrent crypto.KeyExchange
kexCurrentTime time.Time
kexMutex sync.RWMutex
)
// getEphermalKEX returns the currently active KEX, which changes every protocol.EphermalKeyLifetime
// See the explanation from the QUIC crypto doc:
//
// A single connection is the usual scope for forward security, but the security
// difference between an ephemeral key used for a single connection, and one
// used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a
// small time span.
func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock()
res := kexCurrent
t := kexCurrentTime
kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime {
return res, nil
}
kexMutex.Lock()
defer kexMutex.Unlock()
// Check if still unfulfilled
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err
}
kexCurrent = kex
kexCurrentTime = time.Now()
return kexCurrent, nil
}
return kexCurrent, nil
}

View file

@ -1,35 +0,0 @@
package handshake
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Ephermal KEX", func() {
It("has a consistent KEX", func() {
kex1, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex1).ToNot(BeNil())
kex2, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex2).ToNot(BeNil())
Expect(kex1).To(Equal(kex2))
})
It("changes KEX", func() {
kexLifetime = 10 * time.Millisecond
defer func() {
kexLifetime = protocol.EphermalKeyLifetime
}()
kex, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex).ToNot(BeNil())
time.Sleep(kexLifetime)
kex2, err := getEphermalKEX()
Expect(err).ToNot(HaveOccurred())
Expect(kex2).ToNot(Equal(kex))
})
})

View file

@ -1,137 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"sort"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// A HandshakeMessage is a handshake message
type HandshakeMessage struct {
Tag Tag
Data map[Tag][]byte
}
var _ fmt.Stringer = &HandshakeMessage{}
// ParseHandshakeMessage reads a crypto message
func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
slice4 := make([]byte, 4)
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
messageTag := Tag(binary.LittleEndian.Uint32(slice4))
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
nPairs := binary.LittleEndian.Uint32(slice4)
if nPairs > protocol.CryptoMaxParams {
return HandshakeMessage{}, qerr.CryptoTooManyEntries
}
index := make([]byte, nPairs*8)
if _, err := io.ReadFull(r, index); err != nil {
return HandshakeMessage{}, err
}
resultMap := map[Tag][]byte{}
var dataStart uint32
for indexPos := 0; indexPos < int(nPairs)*8; indexPos += 8 {
tag := Tag(binary.LittleEndian.Uint32(index[indexPos : indexPos+4]))
dataEnd := binary.LittleEndian.Uint32(index[indexPos+4 : indexPos+8])
dataLen := dataEnd - dataStart
if dataLen > protocol.CryptoParameterMaxLength {
return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
}
data := make([]byte, dataLen)
if _, err := io.ReadFull(r, data); err != nil {
return HandshakeMessage{}, err
}
resultMap[tag] = data
dataStart = dataEnd
}
return HandshakeMessage{
Tag: messageTag,
Data: resultMap}, nil
}
// Write writes a crypto message
func (h HandshakeMessage) Write(b *bytes.Buffer) {
data := h.Data
utils.LittleEndian.WriteUint32(b, uint32(h.Tag))
utils.LittleEndian.WriteUint16(b, uint16(len(data)))
utils.LittleEndian.WriteUint16(b, 0)
// Save current position in the buffer, so that we can update the index in-place later
indexStart := b.Len()
indexData := make([]byte, 8*len(data))
b.Write(indexData) // Will be updated later
offset := uint32(0)
for i, t := range h.getTagsSorted() {
v := data[t]
b.Write(v)
offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
}
// Now we write the index data for real
copy(b.Bytes()[indexStart:], indexData)
}
func (h *HandshakeMessage) getTagsSorted() []Tag {
tags := make([]Tag, len(h.Data))
i := 0
for t := range h.Data {
tags[i] = t
i++
}
sort.Slice(tags, func(i, j int) bool {
return tags[i] < tags[j]
})
return tags
}
func (h HandshakeMessage) String() string {
var pad string
res := tagToString(h.Tag) + ":\n"
for _, tag := range h.getTagsSorted() {
if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else {
res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
}
}
if len(pad) > 0 {
res += pad
}
return res
}
func tagToString(tag Tag) string {
b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, uint32(tag))
for i := range b {
if b[i] == 0 {
b[i] = ' '
}
}
return string(b)
}

View file

@ -1,71 +0,0 @@
package handshake
import (
"bytes"
"github.com/lucas-clemente/quic-go/qerr"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake Message", func() {
Context("when parsing", func() {
It("parses sample CHLO message", func() {
msg, err := ParseHandshakeMessage(bytes.NewReader(sampleCHLO))
Expect(err).ToNot(HaveOccurred())
Expect(msg.Tag).To(Equal(TagCHLO))
Expect(msg.Data).To(Equal(sampleCHLOMap))
})
It("rejects large numbers of pairs", func() {
r := bytes.NewReader([]byte("CHLO\xff\xff\xff\xff"))
_, err := ParseHandshakeMessage(r)
Expect(err).To(MatchError(qerr.CryptoTooManyEntries))
})
It("rejects too long values", func() {
r := bytes.NewReader([]byte{
'C', 'H', 'L', 'O',
1, 0, 0, 0,
0, 0, 0, 0,
0xff, 0xff, 0xff, 0xff,
})
_, err := ParseHandshakeMessage(r)
Expect(err).To(MatchError(qerr.Error(qerr.CryptoInvalidValueLength, "value too long")))
})
})
Context("when writing", func() {
It("writes sample message", func() {
b := &bytes.Buffer{}
HandshakeMessage{Tag: TagCHLO, Data: sampleCHLOMap}.Write(b)
Expect(b.Bytes()).To(Equal(sampleCHLO))
})
})
Context("string representation", func() {
It("has a string representation", func() {
str := HandshakeMessage{
Tag: TagSHLO,
Data: map[Tag][]byte{
TagAEAD: []byte("foobar"),
TagEXPY: []byte("raboof"),
},
}.String()
Expect(str[:4]).To(Equal("SHLO"))
Expect(str).To(ContainSubstring("AEAD: \"foobar\""))
Expect(str).To(ContainSubstring("EXPY: \"raboof\""))
})
It("lists padding separately", func() {
str := HandshakeMessage{
Tag: TagSHLO,
Data: map[Tag][]byte{
TagPAD: bytes.Repeat([]byte{0}, 1337),
},
}.String()
Expect(str).To(ContainSubstring("PAD"))
Expect(str).To(ContainSubstring("1337 bytes"))
})
})
})

View file

@ -25,28 +25,17 @@ type tlsExtensionHandler interface {
ReceivedExtensions(msgType uint8, exts []qtls.Extension) error
}
type baseCryptoSetup interface {
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake() error
io.Closer
HandleMessage([]byte, protocol.EncryptionLevel) bool
ConnectionState() ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
}
// CryptoSetup is the crypto setup used by gQUIC
type CryptoSetup interface {
baseCryptoSetup
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, protocol.EncryptionLevel, error)
}
// CryptoSetupTLS is the crypto setup used by IETF QUIC
type CryptoSetupTLS interface {
baseCryptoSetup
io.Closer
HandleMessage([]byte, protocol.EncryptionLevel) bool
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)

View file

@ -1,73 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"github.com/lucas-clemente/quic-go/internal/crypto"
)
// ServerConfig is a server config
type ServerConfig struct {
kex crypto.KeyExchange
certChain crypto.CertChain
ID []byte
obit []byte
cookieGenerator *CookieGenerator
}
// NewServerConfig creates a new server config
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
id := make([]byte, 16)
_, err := rand.Read(id)
if err != nil {
return nil, err
}
obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil {
return nil, err
}
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}
return &ServerConfig{
kex: kex,
certChain: certChain,
ID: id,
obit: obit,
cookieGenerator: cookieGenerator,
}, nil
}
// Get the server config binary representation
func (s *ServerConfig) Get() []byte {
var serverConfig bytes.Buffer
msg := HandshakeMessage{
Tag: TagSCFG,
Data: map[Tag][]byte{
TagSCID: s.ID,
TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
TagOBIT: s.obit,
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
}
msg.Write(&serverConfig)
return serverConfig.Bytes()
}
// Sign the server config and CHLO with the server's keyData
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
return s.certChain.SignServerProof(sni, chlo, s.Get())
}
// GetCertsCompressed returns the certificate data
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
}

View file

@ -1,184 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type serverConfigClient struct {
raw []byte
ID []byte
obit []byte
expiry time.Time
kex crypto.KeyExchange
sharedSecret []byte
}
var (
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
)
// parseServerConfig parses a server config
func parseServerConfig(data []byte) (*serverConfigClient, error) {
message, err := ParseHandshakeMessage(bytes.NewReader(data))
if err != nil {
return nil, err
}
if message.Tag != TagSCFG {
return nil, errMessageNotServerConfig
}
scfg := &serverConfigClient{raw: data}
err = scfg.parseValues(message.Data)
if err != nil {
return nil, err
}
return scfg, nil
}
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
// SCID
scfgID, ok := tagMap[TagSCID]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
}
if len(scfgID) != 16 {
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
}
s.ID = scfgID
// KEXS
// TODO: setup Key Exchange
kexs, ok := tagMap[TagKEXS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
}
if len(kexs)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
}
c255Foundat := -1
for i := 0; i < len(kexs)/4; i++ {
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
c255Foundat = i
break
}
}
if c255Foundat < 0 {
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
}
// AEAD
aead, ok := tagMap[TagAEAD]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
}
if len(aead)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
}
var aesgFound bool
for i := 0; i < len(aead)/4; i++ {
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
aesgFound = true
break
}
}
if !aesgFound {
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
}
// PUBS
pubs, ok := tagMap[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
var pubsKexs []struct {
Length uint32
Value []byte
}
var lastLen uint32
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
// the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
}
if lastLen == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
if i+3+int(lastLen) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
pubsKexs = append(pubsKexs, struct {
Length uint32
Value []byte
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
}
if c255Foundat >= len(pubsKexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
}
if pubsKexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
var err error
s.kex, err = crypto.NewCurve25519KEX()
if err != nil {
return err
}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
if err != nil {
return err
}
// OBIT
obit, ok := tagMap[TagOBIT]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
}
if len(obit) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
}
s.obit = obit
// EXPY
expy, ok := tagMap[TagEXPY]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
}
if len(expy) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
}
// make sure that the value doesn't overflow an int64
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
s.expiry = time.Unix(int64(expyTimestamp), 0)
// TODO: implement VER
return nil
}
func (s *serverConfigClient) IsExpired() bool {
return s.expiry.Before(time.Now())
}
func (s *serverConfigClient) Get() []byte {
return s.raw
}

View file

@ -1,266 +0,0 @@
package handshake
import (
"bytes"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// This tagMap can be passed to parseValues and is garantueed to not cause any errors
func getDefaultServerConfigClient() map[Tag][]byte {
return map[Tag][]byte{
TagSCID: bytes.Repeat([]byte{'F'}, 16),
TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, bytes.Repeat([]byte{0}, 32)...),
TagOBIT: bytes.Repeat([]byte{0}, 8),
TagEXPY: {0x0, 0x6c, 0x57, 0x78, 0, 0, 0, 0}, // 2033-12-24
}
}
var _ = Describe("Server Config", func() {
var tagMap map[Tag][]byte
BeforeEach(func() {
tagMap = getDefaultServerConfigClient()
})
It("returns the parsed server config", func() {
tagMap[TagSCID] = []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
b := &bytes.Buffer{}
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
scfg, err := parseServerConfig(b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
})
It("saves the raw server config", func() {
b := &bytes.Buffer{}
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
scfg, err := parseServerConfig(b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.raw).To(Equal(b.Bytes()))
})
It("tells if a server config is expired", func() {
scfg := &serverConfigClient{}
scfg.expiry = time.Now().Add(-time.Second)
Expect(scfg.IsExpired()).To(BeTrue())
scfg.expiry = time.Now().Add(time.Second)
Expect(scfg.IsExpired()).To(BeFalse())
})
Context("parsing the server config", func() {
It("rejects a handshake message with the wrong message tag", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagCHLO, Data: make(map[Tag][]byte)}.Write(&serverConfig)
_, err := parseServerConfig(serverConfig.Bytes())
Expect(err).To(MatchError(errMessageNotServerConfig))
})
It("errors on invalid handshake messages", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
_, err := parseServerConfig(serverConfig.Bytes()[:serverConfig.Len()-2])
Expect(err).To(MatchError("unexpected EOF"))
})
It("passes on errors encountered when reading the TagMap", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
_, err := parseServerConfig(serverConfig.Bytes())
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
})
It("reads an example Handshake Message", func() {
var serverConfig bytes.Buffer
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(&serverConfig)
scfg, err := parseServerConfig(serverConfig.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
Expect(scfg.obit).To(Equal(tagMap[TagOBIT]))
})
})
Context("Reading values from the TagMap", func() {
var scfg *serverConfigClient
BeforeEach(func() {
scfg = &serverConfigClient{}
})
Context("ServerConfig ID", func() {
It("parses the ServerConfig ID", func() {
id := []byte{0xb2, 0xa4, 0xbb, 0x8f, 0xf6, 0x51, 0x28, 0xfd, 0x4d, 0xf7, 0xb3, 0x9a, 0x91, 0xe7, 0x91, 0xfb}
tagMap[TagSCID] = id
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
Expect(scfg.ID).To(Equal(id))
})
It("errors if the ServerConfig ID is missing", func() {
delete(tagMap, TagSCID)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
})
It("rejects ServerConfig IDs that have the wrong length", func() {
tagMap[TagSCID] = bytes.Repeat([]byte{'F'}, 17) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: SCID"))
})
})
Context("KEXS", func() {
It("rejects KEXS values that have the wrong length", func() {
tagMap[TagKEXS] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: KEXS"))
})
It("rejects KEXS values other than C255", func() {
tagMap[TagKEXS] = []byte("P256")
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoNoSupport: KEXS: Could not find C255, other key exchanges are not supported"))
})
It("errors if the KEXS is missing", func() {
delete(tagMap, TagKEXS)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: KEXS"))
})
})
Context("AEAD", func() {
It("rejects AEAD values that have the wrong length", func() {
tagMap[TagAEAD] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: AEAD"))
})
It("rejects AEAD values other than AESG", func() {
tagMap[TagAEAD] = []byte("S20P")
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoNoSupport: AEAD"))
})
It("recognizes AESG in the list of AEADs, at the first position", func() {
tagMap[TagAEAD] = []byte("AESGS20P")
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
})
It("recognizes AESG in the list of AEADs, not at the first position", func() {
tagMap[TagAEAD] = []byte("S20PAESG")
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
})
It("errors if the AEAD is missing", func() {
delete(tagMap, TagAEAD)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: AEAD"))
})
})
Context("PUBS", func() {
It("creates a Curve25519 key exchange", func() {
serverKex, err := crypto.NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
tagMap[TagPUBS] = append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)
err = scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
})
It("rejects PUBS values that have the wrong length", func() {
tagMap[TagPUBS] = bytes.Repeat([]byte{'F'}, 100) // completely wrong length
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
})
It("rejects PUBS values that have a zero length", func() {
tagMap[TagPUBS] = bytes.Repeat([]byte{0}, 100) // completely wrong length
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
})
It("ensure that C255 Pubs must not be at the first index", func() {
serverKex, err := crypto.NewCurve25519KEX()
Expect(err).ToNot(HaveOccurred())
tagMap[TagKEXS] = []byte("P256C255") // have another KEXS before C255
// 3 byte len + 1 byte empty + C255
tagMap[TagPUBS] = append([]byte{0x01, 0x00, 0x00, 0x00}, append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)...)
err = scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
Expect(err).ToNot(HaveOccurred())
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
})
It("errors if the PUBS is missing", func() {
delete(tagMap, TagPUBS)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: PUBS"))
})
})
Context("OBIT", func() {
It("parses the OBIT value", func() {
obit := []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}
tagMap[TagOBIT] = obit
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
Expect(scfg.obit).To(Equal(obit))
})
It("errors if the OBIT is missing", func() {
delete(tagMap, TagOBIT)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: OBIT"))
})
It("rejets OBIT values that have the wrong length", func() {
tagMap[TagOBIT] = bytes.Repeat([]byte{'F'}, 7) // 1 byte too short
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: OBIT"))
})
})
Context("EXPY", func() {
It("parses the expiry date", func() {
tagMap[TagEXPY] = []byte{0xdc, 0x89, 0x0e, 0x59, 0, 0, 0, 0} // UNIX Timestamp 0x590e89dc = 1494125020
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
year, month, day := scfg.expiry.UTC().Date()
Expect(year).To(Equal(2017))
Expect(month).To(Equal(time.Month(5)))
Expect(day).To(Equal(7))
})
It("errors if the EXPY is missing", func() {
delete(tagMap, TagEXPY)
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoMessageParameterNotFound: EXPY"))
})
It("rejects EXPY values that have the wrong length", func() {
tagMap[TagEXPY] = bytes.Repeat([]byte{'F'}, 9) // 1 byte too long
err := scfg.parseValues(tagMap)
Expect(err).To(MatchError("CryptoInvalidValueLength: EXPY"))
})
It("deals with absurdly large timestamps", func() {
tagMap[TagEXPY] = bytes.Repeat([]byte{0xff}, 8) // this would overflow the int64
err := scfg.parseValues(tagMap)
Expect(err).ToNot(HaveOccurred())
Expect(scfg.expiry.After(time.Now())).To(BeTrue())
})
})
})
})

View file

@ -1,45 +0,0 @@
package handshake
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/crypto"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("ServerConfig", func() {
var (
kex crypto.KeyExchange
)
BeforeEach(func() {
var err error
kex, err = crypto.NewCurve25519KEX()
Expect(err).NotTo(HaveOccurred())
})
It("generates a random ID and OBIT", func() {
scfg1, err := NewServerConfig(kex, nil)
Expect(err).ToNot(HaveOccurred())
scfg2, err := NewServerConfig(kex, nil)
Expect(err).ToNot(HaveOccurred())
Expect(scfg1.ID).ToNot(Equal(scfg2.ID))
Expect(scfg1.obit).ToNot(Equal(scfg2.obit))
Expect(scfg1.cookieGenerator).ToNot(Equal(scfg2.cookieGenerator))
})
It("gets the proper binary representation", func() {
scfg, err := NewServerConfig(kex, nil)
Expect(err).NotTo(HaveOccurred())
expected := bytes.NewBuffer([]byte{0x53, 0x43, 0x46, 0x47, 0x6, 0x0, 0x0, 0x0, 0x41, 0x45, 0x41, 0x44, 0x4, 0x0, 0x0, 0x0, 0x53, 0x43, 0x49, 0x44, 0x14, 0x0, 0x0, 0x0, 0x50, 0x55, 0x42, 0x53, 0x37, 0x0, 0x0, 0x0, 0x4b, 0x45, 0x58, 0x53, 0x3b, 0x0, 0x0, 0x0, 0x4f, 0x42, 0x49, 0x54, 0x43, 0x0, 0x0, 0x0, 0x45, 0x58, 0x50, 0x59, 0x4b, 0x0, 0x0, 0x0, 0x41, 0x45, 0x53, 0x47})
expected.Write(scfg.ID)
expected.Write([]byte{0x20, 0x0, 0x0})
expected.Write(kex.PublicKey())
expected.Write([]byte{0x43, 0x32, 0x35, 0x35})
expected.Write(scfg.obit)
expected.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
Expect(scfg.Get()).To(Equal(expected.Bytes()))
})
})

View file

@ -1,93 +0,0 @@
package handshake
// A Tag in the QUIC crypto
type Tag uint32
const (
// TagCHLO is a client hello
TagCHLO Tag = 'C' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagREJ is a server hello rejection
TagREJ Tag = 'R' + 'E'<<8 + 'J'<<16
// TagSCFG is a server config
TagSCFG Tag = 'S' + 'C'<<8 + 'F'<<16 + 'G'<<24
// TagPAD is padding
TagPAD Tag = 'P' + 'A'<<8 + 'D'<<16
// TagSNI is the server name indication
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
// TagVER is the QUIC version
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
// TagCCS are the hashes of the common certificate sets
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
// TagCCRT are the hashes of the cached certificates
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
// TagMSPC is max streams per connection
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
// TagMIDS is max incoming dyanamic streams
TagMIDS Tag = 'M' + 'I'<<8 + 'D'<<16 + 'S'<<24
// TagUAID is the user agent ID
TagUAID Tag = 'U' + 'A'<<8 + 'I'<<16 + 'D'<<24
// TagSVID is the server ID (unofficial tag by us :)
TagSVID Tag = 'S' + 'V'<<8 + 'I'<<16 + 'D'<<24
// TagTCID is truncation of the connection ID
TagTCID Tag = 'T' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagPDMD is the proof demand
TagPDMD Tag = 'P' + 'D'<<8 + 'M'<<16 + 'D'<<24
// TagSRBF is the socket receive buffer
TagSRBF Tag = 'S' + 'R'<<8 + 'B'<<16 + 'F'<<24
// TagICSL is the idle connection state lifetime
TagICSL Tag = 'I' + 'C'<<8 + 'S'<<16 + 'L'<<24
// TagNONP is the client proof nonce
TagNONP Tag = 'N' + 'O'<<8 + 'N'<<16 + 'P'<<24
// TagSCLS is the silently close timeout
TagSCLS Tag = 'S' + 'C'<<8 + 'L'<<16 + 'S'<<24
// TagCSCT is the signed cert timestamp (RFC6962) of leaf cert
TagCSCT Tag = 'C' + 'S'<<8 + 'C'<<16 + 'T'<<24
// TagCOPT are the connection options
TagCOPT Tag = 'C' + 'O'<<8 + 'P'<<16 + 'T'<<24
// TagCFCW is the initial session/connection flow control receive window
TagCFCW Tag = 'C' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagSFCW is the initial stream flow control receive window.
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagNSTP is the no STOP_WAITING experiment
// currently unsupported by quic-go
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
// TagSTK is the source-address token
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
// TagSNO is the server nonce
TagSNO Tag = 'S' + 'N'<<8 + 'O'<<16
// TagPROF is the server proof
TagPROF Tag = 'P' + 'R'<<8 + 'O'<<16 + 'F'<<24
// TagNONC is the client nonce
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
// TagXLCT is the expected leaf certificate
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
// TagSCID is the server config ID
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagKEXS is the list of key exchange algos
TagKEXS Tag = 'K' + 'E'<<8 + 'X'<<16 + 'S'<<24
// TagAEAD is the list of AEAD algos
TagAEAD Tag = 'A' + 'E'<<8 + 'A'<<16 + 'D'<<24
// TagPUBS is the public value for the KEX
TagPUBS Tag = 'P' + 'U'<<8 + 'B'<<16 + 'S'<<24
// TagOBIT is the client orbit
TagOBIT Tag = 'O' + 'B'<<8 + 'I'<<16 + 'T'<<24
// TagEXPY is the server config expiry
TagEXPY Tag = 'E' + 'X'<<8 + 'P'<<16 + 'Y'<<24
// TagCERT is the CERT data
TagCERT Tag = 0xff545243
// TagSHLO is the server hello
TagSHLO Tag = 'S' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagPRST is the public reset tag
TagPRST Tag = 'P' + 'R'<<8 + 'S'<<16 + 'T'<<24
// TagRSEQ is the public reset rejected packet number
TagRSEQ Tag = 'R' + 'S'<<8 + 'E'<<16 + 'Q'<<24
// TagRNON is the public reset nonce
TagRNON Tag = 'R' + 'N'<<8 + 'O'<<16 + 'N'<<24
)

View file

@ -17,14 +17,15 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
handler *extensionHandlerClient
paramsChan <-chan TransportParameters
)
version := protocol.VersionNumber(0x42)
BeforeEach(func() {
var h tlsExtensionHandler
h, paramsChan = newExtensionHandlerClient(
&TransportParameters{},
protocol.VersionWhatever,
version,
nil,
protocol.VersionWhatever,
version,
utils.DefaultLogger,
)
handler = h.(*extensionHandlerClient)
@ -57,6 +58,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
Type: quicTLSExtensionType,
Data: (&encryptedExtensionsTransportParameters{
Parameters: params,
NegotiatedVersion: version,
SupportedVersions: []protocol.VersionNumber{handler.version},
}).Marshal(),
}

View file

@ -12,260 +12,164 @@ import (
)
var _ = Describe("Transport Parameters", func() {
Context("for gQUIC", func() {
Context("parsing", func() {
It("sets all values", func() {
values := map[Tag][]byte{
TagSFCW: {0xad, 0xfb, 0xca, 0xde},
TagCFCW: {0xef, 0xbe, 0xad, 0xde},
TagICSL: {0x0d, 0xf0, 0xad, 0xba},
TagMIDS: {0xff, 0x10, 0x00, 0xc0},
}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xdecafbad)))
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0xdeadbeef)))
Expect(params.IdleTimeout).To(Equal(time.Duration(0xbaadf00d) * time.Second))
Expect(params.MaxStreams).To(Equal(uint32(0xc00010ff)))
Expect(params.OmitConnectionID).To(BeFalse())
})
It("has a string representation", func() {
p := &TransportParameters{
StreamFlowControlWindow: 0x1234,
ConnectionFlowControlWindow: 0x4321,
MaxBidiStreams: 1337,
MaxUniStreams: 7331,
IdleTimeout: 42 * time.Second,
}
Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}"))
})
It("reads if the connection ID should be omitted", func() {
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.OmitConnectionID).To(BeTrue())
})
Context("parsing", func() {
var (
params *TransportParameters
parameters map[transportParameterID][]byte
statelessResetToken []byte
)
It("doesn't allow idle timeouts below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
values := map[Tag][]byte{
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
}
params, err := readHelloMap(values)
Expect(err).ToNot(HaveOccurred())
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
})
marshal := func(p map[transportParameterID][]byte) []byte {
b := &bytes.Buffer{}
for id, val := range p {
utils.BigEndian.WriteUint16(b, uint16(id))
utils.BigEndian.WriteUint16(b, uint16(len(val)))
b.Write(val)
}
return b.Bytes()
}
It("errors when given an invalid SFCW value", func() {
values := map[Tag][]byte{TagSFCW: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid CFCW value", func() {
values := map[Tag][]byte{TagCFCW: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid TCID value", func() {
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid ICSL value", func() {
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
It("errors when given an invalid MIDS value", func() {
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
_, err := readHelloMap(values)
Expect(err).To(MatchError(errMalformedTag))
})
BeforeEach(func() {
params = &TransportParameters{}
statelessResetToken = bytes.Repeat([]byte{42}, 16)
parameters = map[transportParameterID][]byte{
initialMaxStreamDataParameterID: {0x11, 0x22, 0x33, 0x44},
initialMaxDataParameterID: {0x22, 0x33, 0x44, 0x55},
initialMaxBidiStreamsParameterID: {0x33, 0x44},
initialMaxUniStreamsParameterID: {0x44, 0x55},
idleTimeoutParameterID: {0x13, 0x37},
maxPacketSizeParameterID: {0x73, 0x31},
disableMigrationParameterID: {},
statelessResetTokenParameterID: statelessResetToken,
}
})
It("reads parameters", func() {
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
Expect(params.MaxBidiStreams).To(Equal(uint16(0x3344)))
Expect(params.MaxUniStreams).To(Equal(uint16(0x4455)))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
Expect(params.MaxPacketSize).To(Equal(protocol.ByteCount(0x7331)))
Expect(params.DisableMigration).To(BeTrue())
Expect(params.StatelessResetToken).To(Equal(statelessResetToken))
})
Context("writing", func() {
It("returns all necessary parameters ", func() {
params := &TransportParameters{
StreamFlowControlWindow: 0xdeadbeef,
ConnectionFlowControlWindow: 0xdecafbad,
IdleTimeout: 0xbaaaaaad * time.Second,
MaxStreams: 0x1337,
}
entryMap := params.getHelloMap()
Expect(entryMap).To(HaveLen(4))
Expect(entryMap).ToNot(HaveKey(TagTCID))
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, []byte{0xef, 0xbe, 0xad, 0xde}))
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, []byte{0xad, 0xfb, 0xca, 0xde}))
Expect(entryMap).To(HaveKeyWithValue(TagICSL, []byte{0xad, 0xaa, 0xaa, 0xba}))
Expect(entryMap).To(HaveKeyWithValue(TagMIDS, []byte{0x37, 0x13, 0, 0}))
})
It("errors if a parameter is sent twice", func() {
data := marshal(parameters)
parameters = map[transportParameterID][]byte{
maxPacketSizeParameterID: {0x73, 0x31},
}
data = append(data, marshal(parameters)...)
err := params.unmarshal(data)
Expect(err).To(MatchError(fmt.Sprintf("received duplicate transport parameter %#x", maxPacketSizeParameterID)))
})
It("requests omission of the connection ID", func() {
params := &TransportParameters{OmitConnectionID: true}
entryMap := params.getHelloMap()
Expect(entryMap).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
})
It("doesn't allow values below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_data has the wrong length", func() {
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
parameters[initialMaxBidiStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_id_bidi: 3 (expected 2)"))
})
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
parameters[initialMaxUniStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_id_uni: 3 (expected 2)"))
})
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
})
It("rejects the parameters if max_packet_size has the wrong length", func() {
parameters[maxPacketSizeParameterID] = []byte{0x11} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for max_packet_size: 1 (expected 2)"))
})
It("rejects max_packet_sizes smaller than 1200 bytes", func() {
parameters[maxPacketSizeParameterID] = []byte{0x4, 0xaf} // 0x4af = 1199
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)"))
})
It("rejects the parameters if disable_connection_migration has the wrong length", func() {
parameters[disableMigrationParameterID] = []byte{0x11} // should empty
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for disable_migration: 1 (expected empty)"))
})
It("rejects the parameters if the stateless_reset_token has the wrong length", func() {
parameters[statelessResetTokenParameterID] = statelessResetToken[1:]
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)"))
})
It("ignores unknown parameters", func() {
parameters[1337] = []byte{42}
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
})
})
Context("for TLS", func() {
It("has a string representation", func() {
p := &TransportParameters{
StreamFlowControlWindow: 0x1234,
ConnectionFlowControlWindow: 0x4321,
MaxBidiStreams: 1337,
MaxUniStreams: 7331,
IdleTimeout: 42 * time.Second,
Context("marshalling", func() {
It("marshals", func() {
params := &TransportParameters{
StreamFlowControlWindow: 0xdeadbeef,
ConnectionFlowControlWindow: 0xdecafbad,
IdleTimeout: 0xcafe * time.Second,
MaxBidiStreams: 0x1234,
MaxUniStreams: 0x4321,
DisableMigration: true,
StatelessResetToken: bytes.Repeat([]byte{100}, 16),
}
Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}"))
})
b := &bytes.Buffer{}
params.marshal(b)
Context("parsing", func() {
var (
params *TransportParameters
parameters map[transportParameterID][]byte
statelessResetToken []byte
)
marshal := func(p map[transportParameterID][]byte) []byte {
b := &bytes.Buffer{}
for id, val := range p {
utils.BigEndian.WriteUint16(b, uint16(id))
utils.BigEndian.WriteUint16(b, uint16(len(val)))
b.Write(val)
}
return b.Bytes()
}
BeforeEach(func() {
params = &TransportParameters{}
statelessResetToken = bytes.Repeat([]byte{42}, 16)
parameters = map[transportParameterID][]byte{
initialMaxStreamDataParameterID: {0x11, 0x22, 0x33, 0x44},
initialMaxDataParameterID: {0x22, 0x33, 0x44, 0x55},
initialMaxBidiStreamsParameterID: {0x33, 0x44},
initialMaxUniStreamsParameterID: {0x44, 0x55},
idleTimeoutParameterID: {0x13, 0x37},
maxPacketSizeParameterID: {0x73, 0x31},
disableMigrationParameterID: {},
statelessResetTokenParameterID: statelessResetToken,
}
})
It("reads parameters", func() {
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
Expect(params.MaxBidiStreams).To(Equal(uint16(0x3344)))
Expect(params.MaxUniStreams).To(Equal(uint16(0x4455)))
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
Expect(params.OmitConnectionID).To(BeFalse())
Expect(params.MaxPacketSize).To(Equal(protocol.ByteCount(0x7331)))
Expect(params.DisableMigration).To(BeTrue())
Expect(params.StatelessResetToken).To(Equal(statelessResetToken))
})
It("errors if a parameter is sent twice", func() {
data := marshal(parameters)
parameters = map[transportParameterID][]byte{
maxPacketSizeParameterID: {0x73, 0x31},
}
data = append(data, marshal(parameters)...)
err := params.unmarshal(data)
Expect(err).To(MatchError(fmt.Sprintf("received duplicate transport parameter %#x", maxPacketSizeParameterID)))
})
It("doesn't allow values below the minimum remote idle timeout", func() {
t := 2 * time.Second
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
})
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_data has the wrong length", func() {
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
})
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
parameters[initialMaxBidiStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_id_bidi: 3 (expected 2)"))
})
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
parameters[initialMaxUniStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for initial_max_stream_id_uni: 3 (expected 2)"))
})
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
})
It("rejects the parameters if max_packet_size has the wrong length", func() {
parameters[maxPacketSizeParameterID] = []byte{0x11} // should be 2 bytes
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for max_packet_size: 1 (expected 2)"))
})
It("rejects max_packet_sizes smaller than 1200 bytes", func() {
parameters[maxPacketSizeParameterID] = []byte{0x4, 0xaf} // 0x4af = 1199
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)"))
})
It("rejects the parameters if disable_connection_migration has the wrong length", func() {
parameters[disableMigrationParameterID] = []byte{0x11} // should empty
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for disable_migration: 1 (expected empty)"))
})
It("rejects the parameters if the stateless_reset_token has the wrong length", func() {
parameters[statelessResetTokenParameterID] = statelessResetToken[1:]
err := params.unmarshal(marshal(parameters))
Expect(err).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)"))
})
It("ignores unknown parameters", func() {
parameters[1337] = []byte{42}
err := params.unmarshal(marshal(parameters))
Expect(err).ToNot(HaveOccurred())
})
})
Context("marshalling", func() {
It("marshals", func() {
params := &TransportParameters{
StreamFlowControlWindow: 0xdeadbeef,
ConnectionFlowControlWindow: 0xdecafbad,
IdleTimeout: 0xcafe * time.Second,
MaxBidiStreams: 0x1234,
MaxUniStreams: 0x4321,
DisableMigration: true,
StatelessResetToken: bytes.Repeat([]byte{100}, 16),
}
b := &bytes.Buffer{}
params.marshal(b)
p := &TransportParameters{}
Expect(p.unmarshal(b.Bytes())).To(Succeed())
Expect(p.StreamFlowControlWindow).To(Equal(params.StreamFlowControlWindow))
Expect(p.ConnectionFlowControlWindow).To(Equal(params.ConnectionFlowControlWindow))
Expect(p.MaxUniStreams).To(Equal(params.MaxUniStreams))
Expect(p.MaxBidiStreams).To(Equal(params.MaxBidiStreams))
Expect(p.IdleTimeout).To(Equal(params.IdleTimeout))
Expect(p.DisableMigration).To(Equal(params.DisableMigration))
Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken))
})
p := &TransportParameters{}
Expect(p.unmarshal(b.Bytes())).To(Succeed())
Expect(p.StreamFlowControlWindow).To(Equal(params.StreamFlowControlWindow))
Expect(p.ConnectionFlowControlWindow).To(Equal(params.ConnectionFlowControlWindow))
Expect(p.MaxUniStreams).To(Equal(params.MaxUniStreams))
Expect(p.MaxBidiStreams).To(Equal(params.MaxBidiStreams))
Expect(p.IdleTimeout).To(Equal(params.IdleTimeout))
Expect(p.DisableMigration).To(Equal(params.DisableMigration))
Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken))
})
})
})

View file

@ -9,12 +9,8 @@ import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// errMalformedTag is returned when the tag value cannot be read
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
StreamFlowControlWindow protocol.ByteCount
@ -22,78 +18,12 @@ type TransportParameters struct {
MaxPacketSize protocol.ByteCount
MaxUniStreams uint16 // only used for IETF QUIC
MaxBidiStreams uint16 // only used for IETF QUIC
MaxStreams uint32 // only used for gQUIC
MaxUniStreams uint16
MaxBidiStreams uint16
OmitConnectionID bool // only used for gQUIC
IdleTimeout time.Duration
DisableMigration bool // only used for IETF QUIC
StatelessResetToken []byte // only used for IETF QUIC
}
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
params := &TransportParameters{}
if value, ok := tags[TagTCID]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.OmitConnectionID = (v == 0)
}
if value, ok := tags[TagMIDS]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.MaxStreams = v
}
if value, ok := tags[TagICSL]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
}
if value, ok := tags[TagSFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.StreamFlowControlWindow = protocol.ByteCount(v)
}
if value, ok := tags[TagCFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
}
return params, nil
}
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
sfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
cfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
mids := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
icsl := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
tags := map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
}
if p.OmitConnectionID {
tags[TagTCID] = []byte{0, 0, 0, 0}
}
return tags
DisableMigration bool
StatelessResetToken []byte
}
func (p *TransportParameters) unmarshal(data []byte) error {
@ -209,7 +139,6 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) {
}
// String returns a string representation, intended for logging.
// It should only used for IETF QUIC.
func (p *TransportParameters) String() string {
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
}

View file

@ -98,18 +98,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0)
}
// GetStopWaitingFrame mocks base method
func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame {
ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0)
ret0, _ := ret[0].(*wire.StopWaitingFrame)
return ret0
}
// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame
func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0)
}
// OnAlarm mocks base method
func (m *MockSentPacketHandler) OnAlarm() error {
ret := m.ctrl.Call(m, "OnAlarm")

View file

@ -0,0 +1,149 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: CryptoSetup)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
)
// MockCryptoSetup is a mock of CryptoSetup interface
type MockCryptoSetup struct {
ctrl *gomock.Controller
recorder *MockCryptoSetupMockRecorder
}
// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup
type MockCryptoSetupMockRecorder struct {
mock *MockCryptoSetup
}
// NewMockCryptoSetup creates a new mock instance
func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup {
mock := &MockCryptoSetup{ctrl: ctrl}
mock.recorder = &MockCryptoSetupMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder {
return m.recorder
}
// Close mocks base method
func (m *MockCryptoSetup) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close))
}
// ConnectionState mocks base method
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
ret := m.ctrl.Call(m, "ConnectionState")
ret0, _ := ret[0].(handshake.ConnectionState)
return ret0
}
// ConnectionState indicates an expected call of ConnectionState
func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
}
// GetSealer mocks base method
func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
ret := m.ctrl.Call(m, "GetSealer")
ret0, _ := ret[0].(protocol.EncryptionLevel)
ret1, _ := ret[1].(handshake.Sealer)
return ret0, ret1
}
// GetSealer indicates an expected call of GetSealer
func (mr *MockCryptoSetupMockRecorder) GetSealer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealer))
}
// GetSealerWithEncryptionLevel mocks base method
func (m *MockCryptoSetup) GetSealerWithEncryptionLevel(arg0 protocol.EncryptionLevel) (handshake.Sealer, error) {
ret := m.ctrl.Call(m, "GetSealerWithEncryptionLevel", arg0)
ret0, _ := ret[0].(handshake.Sealer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSealerWithEncryptionLevel indicates an expected call of GetSealerWithEncryptionLevel
func (mr *MockCryptoSetupMockRecorder) GetSealerWithEncryptionLevel(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerWithEncryptionLevel", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealerWithEncryptionLevel), arg0)
}
// HandleMessage mocks base method
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// HandleMessage indicates an expected call of HandleMessage
func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
}
// Open1RTT mocks base method
func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open1RTT indicates an expected call of Open1RTT
func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
}
// OpenHandshake mocks base method
func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenHandshake indicates an expected call of OpenHandshake
func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
}
// OpenInitial mocks base method
func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenInitial indicates an expected call of OpenInitial
func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
}
// RunHandshake mocks base method
func (m *MockCryptoSetup) RunHandshake() error {
ret := m.ctrl.Call(m, "RunHandshake")
ret0, _ := ret[0].(error)
return ret0
}
// RunHandshake indicates an expected call of RunHandshake
func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
}

View file

@ -1,6 +1,7 @@
package mocks
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"

View file

@ -7,30 +7,16 @@ type EncryptionLevel int
const (
// EncryptionUnspecified is a not specified encryption level
EncryptionUnspecified EncryptionLevel = iota
// EncryptionUnencrypted is not encrypted, for gQUIC
EncryptionUnencrypted
// EncryptionInitial is the Initial encryption level
EncryptionInitial
// EncryptionSecure is encrypted, but not forward secure
EncryptionSecure
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake
// EncryptionForwardSecure is forward secure
EncryptionForwardSecure
// Encryption1RTT is the 1-RTT encryption level
Encryption1RTT
)
func (e EncryptionLevel) String() string {
switch e {
// gQUIC
case EncryptionUnencrypted:
return "unencrypted"
case EncryptionSecure:
return "encrypted (not forward-secure)"
case EncryptionForwardSecure:
return "forward-secure"
// IETF QUIC
case EncryptionInitial:
return "Initial"
case EncryptionHandshake:

View file

@ -8,9 +8,6 @@ import (
var _ = Describe("Encryption Level", func() {
It("has the correct string representation", func() {
Expect(EncryptionUnspecified.String()).To(Equal("unknown"))
Expect(EncryptionUnencrypted.String()).To(Equal("unencrypted"))
Expect(EncryptionSecure.String()).To(Equal("encrypted (not forward-secure)"))
Expect(EncryptionForwardSecure.String()).To(Equal("forward-secure"))
Expect(EncryptionInitial.String()).To(Equal("Initial"))
Expect(EncryptionHandshake.String()).To(Equal("Handshake"))
Expect(Encryption1RTT.String()).To(Equal("1-RTT"))

View file

@ -8,17 +8,13 @@ func InferPacketNumber(
version VersionNumber,
) PacketNumber {
var epochDelta PacketNumber
if version.UsesVarintPacketNumbers() {
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
}
} else {
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
}
epoch := lastPacketNumber & ^(epochDelta - 1)
prevEpochBegin := epoch - epochDelta
@ -48,8 +44,7 @@ func delta(a, b PacketNumber) PacketNumber {
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
if diff < (1 << (14 - 1)) {
return PacketNumberLen2
}
return PacketNumberLen4

View file

@ -12,17 +12,15 @@ import (
var _ = Describe("packet number calculation", func() {
Context("infering a packet number", func() {
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 {
if v.UsesVarintPacketNumbers() {
switch len {
case PacketNumberLen1:
return uint64(1) << 7
case PacketNumberLen2:
return uint64(1) << 14
case PacketNumberLen4:
return uint64(1) << 30
default:
Fail("invalid packet number len")
}
switch len {
case PacketNumberLen1:
return uint64(1) << 7
case PacketNumberLen2:
return uint64(1) << 14
case PacketNumberLen4:
return uint64(1) << 30
default:
Fail("invalid packet number len")
}
return uint64(1) << (len * 8)
}
@ -33,134 +31,128 @@ var _ = Describe("packet number calculation", func() {
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected)))
}
for _, v := range []VersionNumber{Version39, VersionTLS} {
version := v
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
length := l
Context(fmt.Sprintf("using varint packet numbers: %t", version.UsesVarintPacketNumbers()), func() {
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
length := l
Context(fmt.Sprintf("with %d bytes", length), func() {
epoch := getEpoch(length, VersionWhatever)
epochMask := epoch - 1
Context(fmt.Sprintf("with %d bytes", length), func() {
epoch := getEpoch(length, version)
epochMask := epoch - 1
It("works near epoch start", func() {
// A few quick manual sanity check
check(length, 1, 0, VersionWhatever)
check(length, epoch+1, epochMask, VersionWhatever)
check(length, epoch, epochMask, VersionWhatever)
It("works near epoch start", func() {
// A few quick manual sanity check
check(length, 1, 0, version)
check(length, epoch+1, epochMask, version)
check(length, epoch, epochMask, version)
// Cases where the last number was close to the start of the range.
for last := uint64(0); last < 10; last++ {
// Small numbers should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, j, last, VersionWhatever)
}
// Cases where the last number was close to the start of the range.
for last := uint64(0); last < 10; last++ {
// Small numbers should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, j, last, version)
}
// Large numbers should not wrap either (because we're near 0 already).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, VersionWhatever)
}
}
})
// Large numbers should not wrap either (because we're near 0 already).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, version)
}
}
})
It("works near epoch end", func() {
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := epoch - i
It("works near epoch end", func() {
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := epoch - i
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, epoch+j, last, VersionWhatever)
}
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, epoch+j, last, version)
}
// Large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, VersionWhatever)
}
}
})
// Large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, epoch-1-j, last, version)
}
}
})
// Next check where we're in a non-zero epoch to verify we handle
// reverse wrapping, too.
It("works near previous epoch", func() {
prevEpoch := 1 * epoch
curEpoch := 2 * epoch
// Cases where the last number was close to the start of the range
for i := uint64(0); i < 10; i++ {
last := curEpoch + i
// Small number should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, curEpoch+j, last, VersionWhatever)
}
// Next check where we're in a non-zero epoch to verify we handle
// reverse wrapping, too.
It("works near previous epoch", func() {
prevEpoch := 1 * epoch
curEpoch := 2 * epoch
// Cases where the last number was close to the start of the range
for i := uint64(0); i < 10; i++ {
last := curEpoch + i
// Small number should not wrap (even if they're out of order).
for j := uint64(0); j < 10; j++ {
check(length, curEpoch+j, last, version)
}
// But large numbers should reverse wrap.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, prevEpoch+num, last, VersionWhatever)
}
}
})
// But large numbers should reverse wrap.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, prevEpoch+num, last, version)
}
}
})
It("works near next epoch", func() {
curEpoch := 2 * epoch
nextEpoch := 3 * epoch
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := nextEpoch - 1 - i
It("works near next epoch", func() {
curEpoch := 2 * epoch
nextEpoch := 3 * epoch
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
last := nextEpoch - 1 - i
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, nextEpoch+j, last, VersionWhatever)
}
// Small numbers should wrap.
for j := uint64(0); j < 10; j++ {
check(length, nextEpoch+j, last, version)
}
// but large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, curEpoch+num, last, VersionWhatever)
}
}
})
// but large numbers should not (even if they're out of order).
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, curEpoch+num, last, version)
}
}
})
It("works near next max", func() {
maxNumber := uint64(math.MaxUint64)
maxEpoch := maxNumber & ^epochMask
It("works near next max", func() {
maxNumber := uint64(math.MaxUint64)
maxEpoch := maxNumber & ^epochMask
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
// Subtract 1, because the expected next packet number is 1 more than the
// last packet number.
last := maxNumber - i - 1
// Cases where the last number was close to the end of the range
for i := uint64(0); i < 10; i++ {
// Subtract 1, because the expected next packet number is 1 more than the
// last packet number.
last := maxNumber - i - 1
// Small numbers should not wrap, because they have nowhere to go.
for j := uint64(0); j < 10; j++ {
check(length, maxEpoch+j, last, VersionWhatever)
}
// Small numbers should not wrap, because they have nowhere to go.
for j := uint64(0); j < 10; j++ {
check(length, maxEpoch+j, last, version)
}
// Large numbers should not wrap either.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, maxEpoch+num, last, version)
}
}
})
})
}
// Large numbers should not wrap either.
for j := uint64(0); j < 10; j++ {
num := epoch - 1 - j
check(length, maxEpoch+num, last, VersionWhatever)
}
}
})
Context("shortening a packet number for the header", func() {
Context("shortening", func() {
It("sends out low packet numbers as 2 byte", func() {
length := GetPacketNumberLengthForHeader(4, 2, version)
length := GetPacketNumberLengthForHeader(4, 2, VersionWhatever)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, version)
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, VersionWhatever)
Expect(length).To(Equal(PacketNumberLen2))
})
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
length := GetPacketNumberLengthForHeader(40000, 2, version)
length := GetPacketNumberLengthForHeader(40000, 2, VersionWhatever)
Expect(length).To(Equal(PacketNumberLen4))
})
})
@ -170,10 +162,10 @@ var _ = Describe("packet number calculation", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
@ -182,28 +174,28 @@ var _ = Describe("packet number calculation", func() {
for i := uint64(1); i < 10000; i++ {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i / 2)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
epochMask := getEpoch(length, version) - 1
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
epochMask := getEpoch(length, VersionWhatever) - 1
wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})
It("also works for larger packet numbers", func() {
var increment uint64
for i := uint64(1); i < getEpoch(PacketNumberLen4, version); i += increment {
for i := uint64(1); i < getEpoch(PacketNumberLen4, VersionWhatever); i += increment {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(1)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
epochMask := getEpoch(length, version) - 1
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
epochMask := getEpoch(length, VersionWhatever) - 1
wirePacketNumber := uint64(packetNumber) & epochMask
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
Expect(inferedPacketNumber).To(Equal(packetNumber))
increment = getEpoch(length, version) / 8
increment = getEpoch(length, VersionWhatever) / 8
}
})
@ -211,10 +203,10 @@ var _ = Describe("packet number calculation", func() {
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
packetNumber := PacketNumber(i)
leastUnacked := PacketNumber(i - 1000)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
Expect(inferedPacketNumber).To(Equal(packetNumber))
}
})

View file

@ -23,7 +23,7 @@ const (
PacketNumberLen6 PacketNumberLen = 6
)
// The PacketType is the Long Header Type (only used for the IETF draft header format)
// The PacketType is the Long Header Type
type PacketType uint8
const (
@ -71,10 +71,7 @@ const MaxReceivePacketSize ByteCount = 1452
// Used in QUIC for congestion window computations in bytes.
const DefaultTCPMSS ByteCount = 1460
// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC)
const MinClientHelloSize = 1024
// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have.
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200
// MaxClientHellos is the maximum number of times we'll send a client hello
@ -83,8 +80,5 @@ const MinInitialPacketSize = 1200
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
const ConnectionIDLenGQUIC = 8
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8

View file

@ -24,10 +24,6 @@ const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS
// session queues for later until it sends a public reset.
const MaxUndecryptablePackets = 10
// PublicResetTimeout is the time to wait before sending a Public Reset when receiving too many undecryptable packets during the handshake
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
const PublicResetTimeout = 500 * time.Millisecond
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
// This is the value that Google servers are using
const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB
@ -65,12 +61,6 @@ const DefaultMaxIncomingStreams = 100
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
const DefaultMaxIncomingUniStreams = 100
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
const MaxStreamsMultiplier = 1.1
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
const MaxStreamsMinimumIncrement = 10
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets
@ -103,20 +93,10 @@ const MaxNonRetransmittableAcks = 19
// prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
// Value taken from Chrome.
const CryptoMaxParams = 128
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
const CryptoParameterMaxLength = 4000
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
// This limits the size of the ClientHello and Certificates that can be received.
const MaxCryptoStreamOffset = 16 * (1 << 10)
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
const EphermalKeyLifetime = time.Minute
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
const MinRemoteIdleTimeout = 5 * time.Second
@ -130,16 +110,13 @@ const DefaultHandshakeTimeout = 10 * time.Second
// after this time all information about the old connection will be deleted
const ClosedSessionDeleteTimeout = time.Minute
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
const NumCachedCertificates = 128
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
// 1. it reduces the framing overhead
// 2. it reduces the head-of-line blocking, when a packet is lost
const MinStreamFrameSize ByteCount = 128
// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write
// MaxAckFrameSize is the maximum size for an ACK frame that we write
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
// The MaxAckFrameSize should be large enough to encode many ACK range,
// but must ensure that a maximum size ACK frame fits into one packet.

View file

@ -5,7 +5,6 @@ type StreamID uint64
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams bidirectional streams.
// It is only valid for IETF QUIC.
func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
if numStreams == 0 {
return 0
@ -21,7 +20,6 @@ func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
// MaxUniStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams unidirectional streams.
// It is only valid for IETF QUIC.
func MaxUniStreamID(numStreams int, pers Perspective) StreamID {
if numStreams == 0 {
return 0

View file

@ -18,32 +18,20 @@ const (
// The version numbers, making grepping easier
const (
Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9
Version43 VersionNumber = gquicVersion0 + 4*0x100 + 0x3
Version44 VersionNumber = gquicVersion0 + 4*0x100 + 0x4
VersionTLS VersionNumber = 101
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionWhatever VersionNumber = 1 // for when the version doesn't matter
VersionUnknown VersionNumber = math.MaxUint32
)
// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
var SupportedVersions = []VersionNumber{
Version44,
Version43,
Version39,
}
var SupportedVersions = []VersionNumber{VersionTLS}
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
}
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
func (vn VersionNumber) UsesTLS() bool {
return !vn.isGQUIC()
}
func (vn VersionNumber) String() string {
switch vn {
case VersionWhatever:
@ -62,56 +50,9 @@ func (vn VersionNumber) String() string {
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
func (vn VersionNumber) ToAltSvc() string {
if vn.isGQUIC() {
return fmt.Sprintf("%d", vn.toGQUICVersion())
}
return fmt.Sprintf("%d", vn)
}
// IsCryptoStream says if a stream is the gQUIC crypto stream.
// It never returns true for IETF QUIC.
func (vn VersionNumber) IsCryptoStream(id StreamID) bool {
return vn.isGQUIC() && id == 1
}
// UsesIETFFrameFormat tells if this version uses the IETF frame format
func (vn VersionNumber) UsesIETFFrameFormat() bool {
return !vn.isGQUIC()
}
// UsesIETFHeaderFormat tells if this version uses the IETF header format
func (vn VersionNumber) UsesIETFHeaderFormat() bool {
return !vn.isGQUIC() || vn >= Version44
}
// UsesLengthInHeader tells if this version uses the Length field in the IETF header
func (vn VersionNumber) UsesLengthInHeader() bool {
return !vn.isGQUIC()
}
// UsesTokenInHeader tells if this version uses the Token field in the IETF header
func (vn VersionNumber) UsesTokenInHeader() bool {
return !vn.isGQUIC()
}
// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames
func (vn VersionNumber) UsesStopWaitingFrames() bool {
return vn.isGQUIC() && vn <= Version43
}
// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
return !vn.isGQUIC()
}
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
if !vn.isGQUIC() {
return true
}
return id != 1 && id != 3
}
func (vn VersionNumber) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion
}

View file

@ -10,39 +10,18 @@ var _ = Describe("Version", func() {
return v&0x0f0f0f0f == 0x0a0a0a0a
}
// version numbers taken from the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
It("has the right gQUIC version number", func() {
Expect(Version39).To(BeEquivalentTo(0x51303339))
Expect(Version43).To(BeEquivalentTo(0x51303433))
Expect(Version44).To(BeEquivalentTo(0x51303434))
})
It("says if a version is valid", func() {
Expect(IsValidVersion(Version39)).To(BeTrue())
Expect(IsValidVersion(Version43)).To(BeTrue())
Expect(IsValidVersion(Version44)).To(BeTrue())
Expect(IsValidVersion(VersionTLS)).To(BeTrue())
Expect(IsValidVersion(VersionWhatever)).To(BeFalse())
Expect(IsValidVersion(VersionUnknown)).To(BeFalse())
Expect(IsValidVersion(1234)).To(BeFalse())
})
It("says if a version supports TLS", func() {
Expect(Version39.UsesTLS()).To(BeFalse())
Expect(Version43.UsesTLS()).To(BeFalse())
Expect(Version44.UsesTLS()).To(BeFalse())
Expect(VersionTLS.UsesTLS()).To(BeTrue())
})
It("versions don't have reserved version numbers", func() {
Expect(isReservedVersion(Version39)).To(BeFalse())
Expect(isReservedVersion(Version43)).To(BeFalse())
Expect(isReservedVersion(Version44)).To(BeFalse())
Expect(isReservedVersion(VersionTLS)).To(BeFalse())
})
It("has the right string representation", func() {
Expect(Version39.String()).To(Equal("gQUIC 39"))
Expect(VersionTLS.String()).To(ContainSubstring("TLS"))
Expect(VersionWhatever.String()).To(Equal("whatever"))
Expect(VersionUnknown.String()).To(Equal("unknown"))
@ -55,88 +34,7 @@ var _ = Describe("Version", func() {
})
It("has the right representation for the H2 Alt-Svc tag", func() {
Expect(Version39.ToAltSvc()).To(Equal("39"))
Expect(Version43.ToAltSvc()).To(Equal("43"))
Expect(Version44.ToAltSvc()).To(Equal("44"))
Expect(VersionTLS.ToAltSvc()).To(Equal("101"))
// check with unsupported version numbers from the wiki
Expect(VersionNumber(0x51303133).ToAltSvc()).To(Equal("13"))
Expect(VersionNumber(0x51303235).ToAltSvc()).To(Equal("25"))
Expect(VersionNumber(0x51303438).ToAltSvc()).To(Equal("48"))
})
It("says if a stream is the crypto stream, for gQUIC", func() {
for _, v := range []VersionNumber{Version39, Version43, Version44} {
version := v
Expect(version.IsCryptoStream(1)).To(BeTrue())
Expect(version.IsCryptoStream(2)).To(BeFalse())
Expect(version.IsCryptoStream(3)).To(BeFalse())
Expect(version.IsCryptoStream(4)).To(BeFalse())
Expect(version.IsCryptoStream(5)).To(BeFalse())
}
})
It("says if a stream is the crypto stream, for TLS", func() {
// all streams contribute to connection-level flow control
for id := StreamID(0); id < 10; id++ {
Expect(VersionTLS.IsCryptoStream(id)).To(BeFalse())
}
})
It("tells if a version uses the IETF frame types", func() {
Expect(Version39.UsesIETFFrameFormat()).To(BeFalse())
Expect(Version43.UsesIETFFrameFormat()).To(BeFalse())
Expect(Version44.UsesIETFFrameFormat()).To(BeFalse())
Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue())
})
It("tells if a version uses the IETF header format", func() {
Expect(Version39.UsesIETFHeaderFormat()).To(BeFalse())
Expect(Version43.UsesIETFHeaderFormat()).To(BeFalse())
Expect(Version44.UsesIETFHeaderFormat()).To(BeTrue())
Expect(VersionTLS.UsesIETFHeaderFormat()).To(BeTrue())
})
It("tells if a version uses varint packet numbers", func() {
Expect(Version39.UsesVarintPacketNumbers()).To(BeFalse())
Expect(Version43.UsesVarintPacketNumbers()).To(BeFalse())
Expect(Version44.UsesVarintPacketNumbers()).To(BeFalse())
Expect(VersionTLS.UsesVarintPacketNumbers()).To(BeTrue())
})
It("tells if a version uses the Length field in the IETF header", func() {
Expect(Version44.UsesLengthInHeader()).To(BeFalse())
Expect(VersionTLS.UsesLengthInHeader()).To(BeTrue())
})
It("tells if a version uses the Token field in the IETF header", func() {
Expect(Version44.UsesTokenInHeader()).To(BeFalse())
Expect(VersionTLS.UsesTokenInHeader()).To(BeTrue())
})
It("tells if a version uses STOP_WAITING frames", func() {
Expect(Version39.UsesStopWaitingFrames()).To(BeTrue())
Expect(Version43.UsesStopWaitingFrames()).To(BeTrue())
Expect(Version44.UsesStopWaitingFrames()).To(BeFalse())
Expect(VersionTLS.UsesStopWaitingFrames()).To(BeFalse())
})
It("says if a stream contributes to connection-level flowcontrol, for gQUIC", func() {
for _, v := range []VersionNumber{Version39, Version43, Version44} {
version := v
Expect(version.StreamContributesToConnectionFlowControl(1)).To(BeFalse())
Expect(version.StreamContributesToConnectionFlowControl(2)).To(BeTrue())
Expect(version.StreamContributesToConnectionFlowControl(3)).To(BeFalse())
Expect(version.StreamContributesToConnectionFlowControl(4)).To(BeTrue())
Expect(version.StreamContributesToConnectionFlowControl(5)).To(BeTrue())
}
})
It("says if a stream contributes to connection-level flowcontrol, for TLS", func() {
// all streams contribute to connection-level flow control
for id := StreamID(0); id < 10; id++ {
Expect(VersionTLS.StreamContributesToConnectionFlowControl(id)).To(BeTrue())
}
})
It("recognizes supported versions", func() {

View file

@ -1,22 +0,0 @@
package utils
import "sync/atomic"
// An AtomicBool is an atomic bool
type AtomicBool struct {
v int32
}
// Set sets the value
func (a *AtomicBool) Set(value bool) {
var n int32
if value {
n = 1
}
atomic.StoreInt32(&a.v, n)
}
// Get gets the value
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.v) != 0
}

View file

@ -1,29 +0,0 @@
package utils
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Atomic Bool", func() {
var a *AtomicBool
BeforeEach(func() {
a = &AtomicBool{}
})
It("has the right default value", func() {
Expect(a.Get()).To(BeFalse())
})
It("sets the value to true", func() {
a.Set(true)
Expect(a.Get()).To(BeTrue())
})
It("sets the value to false", func() {
a.Set(true)
a.Set(false)
Expect(a.Get()).To(BeFalse())
})
})

View file

@ -75,27 +75,6 @@ var _ = Describe("Big Endian encoding / decoding", func() {
})
})
Context("WriteUint24", func() {
It("outputs 3 bytes", func() {
b := &bytes.Buffer{}
BigEndian.WriteUint24(b, uint32(1))
Expect(b.Len()).To(Equal(3))
})
It("outputs a big endian", func() {
num := uint32(0x010203)
b := &bytes.Buffer{}
BigEndian.WriteUint24(b, num)
Expect(b.Bytes()).To(Equal([]byte{0x01, 0x02, 0x03}))
})
It("panics if the value doesn't fit into 24 bits", func() {
num := uint32(0x01020304)
b := &bytes.Buffer{}
Expect(func() { BigEndian.WriteUint24(b, num) }).Should(Panic())
})
})
Context("WriteUint32", func() {
It("outputs 4 bytes", func() {
b := &bytes.Buffer{}
@ -111,69 +90,6 @@ var _ = Describe("Big Endian encoding / decoding", func() {
})
})
Context("WriteUint40", func() {
It("outputs 5 bytes", func() {
b := &bytes.Buffer{}
BigEndian.WriteUint40(b, uint64(1))
Expect(b.Len()).To(Equal(5))
})
It("outputs a big endian", func() {
num := uint64(0xDECAFBAD42)
b := &bytes.Buffer{}
BigEndian.WriteUint40(b, num)
Expect(b.Bytes()).To(Equal([]byte{0xDE, 0xCA, 0xFB, 0xAD, 0x42}))
})
It("panics if the value doesn't fit into 40 bits", func() {
num := uint64(0x010203040506)
b := &bytes.Buffer{}
Expect(func() { BigEndian.WriteUint40(b, num) }).Should(Panic())
})
})
Context("WriteUint48", func() {
It("outputs 6 bytes", func() {
b := &bytes.Buffer{}
BigEndian.WriteUint48(b, uint64(1))
Expect(b.Len()).To(Equal(6))
})
It("outputs a big endian", func() {
num := uint64(0xDEADBEEFCAFE)
b := &bytes.Buffer{}
BigEndian.WriteUint48(b, num)
Expect(b.Bytes()).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE}))
})
It("panics if the value doesn't fit into 48 bits", func() {
num := uint64(0xDEADBEEFCAFE01)
b := &bytes.Buffer{}
Expect(func() { BigEndian.WriteUint48(b, num) }).Should(Panic())
})
})
Context("WriteUint56", func() {
It("outputs 7 bytes", func() {
b := &bytes.Buffer{}
BigEndian.WriteUint56(b, uint64(1))
Expect(b.Len()).To(Equal(7))
})
It("outputs a big endian", func() {
num := uint64(0xEEDDCCBBAA9988)
b := &bytes.Buffer{}
BigEndian.WriteUint56(b, num)
Expect(b.Bytes()).To(Equal([]byte{0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88}))
})
It("panics if the value doesn't fit into 56 bits", func() {
num := uint64(0xEEDDCCBBAA998801)
b := &bytes.Buffer{}
Expect(func() { BigEndian.WriteUint56(b, num) }).Should(Panic())
})
})
Context("WriteUint64", func() {
It("outputs 8 bytes", func() {
b := &bytes.Buffer{}

View file

@ -13,13 +13,6 @@ type ByteOrder interface {
ReadUint16(io.ByteReader) (uint16, error)
WriteUint64(*bytes.Buffer, uint64)
WriteUint56(*bytes.Buffer, uint64)
WriteUint48(*bytes.Buffer, uint64)
WriteUint40(*bytes.Buffer, uint64)
WriteUint32(*bytes.Buffer, uint32)
WriteUint24(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
ReadUfloat16(io.ByteReader) (uint64, error)
WriteUfloat16(*bytes.Buffer, uint64)
}

View file

@ -2,7 +2,6 @@ package utils
import (
"bytes"
"fmt"
"io"
)
@ -97,61 +96,12 @@ func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
})
}
// WriteUint56 writes 56 bit of a uint64
func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) {
if i >= (1 << 56) {
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
}
b.Write([]byte{
uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint48 writes 48 bit of a uint64
func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) {
if i >= (1 << 48) {
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
}
b.Write([]byte{
uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint40 writes 40 bit of a uint64
func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) {
if i >= (1 << 40) {
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
}
b.Write([]byte{
uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint32 writes a uint32
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint24 writes 24 bit of a uint32
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
if i >= (1 << 24) {
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
}
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint16 writes a uint16
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i >> 8), uint8(i)})
}
func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
return readUfloat16(b, l)
}
func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
writeUfloat16(b, l, val)
}

View file

@ -1,157 +0,0 @@
package utils
import (
"bytes"
"fmt"
"io"
)
// LittleEndian is the little-endian implementation of ByteOrder.
var LittleEndian ByteOrder = littleEndian{}
type littleEndian struct{}
var _ ByteOrder = &littleEndian{}
// ReadUintN reads N bytes
func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64
for i := uint8(0); i < length; i++ {
bt, err := b.ReadByte()
if err != nil {
return 0, err
}
res ^= uint64(bt) << (i * 8)
}
return res, nil
}
// ReadUint64 reads a uint64
func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) {
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b5, err = b.ReadByte(); err != nil {
return 0, err
}
if b6, err = b.ReadByte(); err != nil {
return 0, err
}
if b7, err = b.ReadByte(); err != nil {
return 0, err
}
if b8, err = b.ReadByte(); err != nil {
return 0, err
}
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
}
// ReadUint32 reads a uint32
func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
// ReadUint16 reads a uint16
func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
return uint16(b1) + uint16(b2)<<8, nil
}
// WriteUint64 writes a uint64
func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) {
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56),
})
}
// WriteUint56 writes 56 bit of a uint64
func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) {
if i >= (1 << 56) {
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48),
})
}
// WriteUint48 writes 48 bit of a uint64
func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) {
if i >= (1 << 48) {
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40),
})
}
// WriteUint40 writes 40 bit of a uint64
func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) {
if i >= (1 << 40) {
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16),
uint8(i >> 24), uint8(i >> 32),
})
}
// WriteUint32 writes a uint32
func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)})
}
// WriteUint24 writes 24 bit of a uint32
func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) {
if i >= (1 << 24) {
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
}
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)})
}
// WriteUint16 writes a uint16
func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i), uint8(i >> 8)})
}
func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
return readUfloat16(b, l)
}
func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
writeUfloat16(b, l, val)
}

View file

@ -1,212 +0,0 @@
package utils
import (
"bytes"
"io"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Little Endian encoding / decoding", func() {
Context("ReadUint16", func() {
It("reads a little endian", func() {
b := []byte{0x13, 0xEF}
val, err := LittleEndian.ReadUint16(bytes.NewReader(b))
Expect(err).ToNot(HaveOccurred())
Expect(val).To(Equal(uint16(0xEF13)))
})
It("throws an error if less than 2 bytes are passed", func() {
b := []byte{0x13, 0xEF}
for i := 0; i < len(b); i++ {
_, err := LittleEndian.ReadUint16(bytes.NewReader(b[:i]))
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("ReadUint32", func() {
It("reads a little endian", func() {
b := []byte{0x12, 0x35, 0xAB, 0xFF}
val, err := LittleEndian.ReadUint32(bytes.NewReader(b))
Expect(err).ToNot(HaveOccurred())
Expect(val).To(Equal(uint32(0xFFAB3512)))
})
It("throws an error if less than 4 bytes are passed", func() {
b := []byte{0x12, 0x35, 0xAB, 0xFF}
for i := 0; i < len(b); i++ {
_, err := LittleEndian.ReadUint32(bytes.NewReader(b[:i]))
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("ReadUint64", func() {
It("reads a little endian", func() {
b := []byte{0x12, 0x35, 0xAB, 0xFF, 0xEF, 0xBE, 0xAD, 0xDE}
val, err := LittleEndian.ReadUint64(bytes.NewReader(b))
Expect(err).ToNot(HaveOccurred())
Expect(val).To(Equal(uint64(0xDEADBEEFFFAB3512)))
})
It("throws an error if less than 8 bytes are passed", func() {
b := []byte{0x12, 0x35, 0xAB, 0xFF, 0xEF, 0xBE, 0xAD, 0xDE}
for i := 0; i < len(b); i++ {
_, err := LittleEndian.ReadUint64(bytes.NewReader(b[:i]))
Expect(err).To(MatchError(io.EOF))
}
})
})
Context("WriteUint16", func() {
It("outputs 2 bytes", func() {
b := &bytes.Buffer{}
LittleEndian.WriteUint16(b, uint16(1))
Expect(b.Len()).To(Equal(2))
})
It("outputs a little endian", func() {
num := uint16(0xFF11)
b := &bytes.Buffer{}
LittleEndian.WriteUint16(b, num)
Expect(b.Bytes()).To(Equal([]byte{0x11, 0xFF}))
})
})
Context("WriteUint24", func() {
It("outputs 3 bytes", func() {
b := &bytes.Buffer{}
LittleEndian.WriteUint24(b, uint32(1))
Expect(b.Len()).To(Equal(3))
})
It("outputs a little endian", func() {
num := uint32(0x010203)
b := &bytes.Buffer{}
LittleEndian.WriteUint24(b, num)
Expect(b.Bytes()).To(Equal([]byte{0x03, 0x02, 0x01}))
})
It("panics if the value doesn't fit into 24 bits", func() {
num := uint32(0x01020304)
b := &bytes.Buffer{}
Expect(func() { LittleEndian.WriteUint24(b, num) }).Should(Panic())
})
})
Context("WriteUint32", func() {
It("outputs 4 bytes", func() {
b := &bytes.Buffer{}
LittleEndian.WriteUint32(b, uint32(1))
Expect(b.Len()).To(Equal(4))
})
It("outputs a little endian", func() {
num := uint32(0xEFAC3512)
b := &bytes.Buffer{}
LittleEndian.WriteUint32(b, num)
Expect(b.Bytes()).To(Equal([]byte{0x12, 0x35, 0xAC, 0xEF}))
})
})
Context("WriteUint40", func() {
It("outputs 5 bytes", func() {
b := &bytes.Buffer{}
LittleEndian.WriteUint40(b, uint64(1))
Expect(b.Len()).To(Equal(5))
})
It("outputs a little endian", func() {
num := uint64(0x0102030405)
b := &bytes.Buffer{}
LittleEndian.WriteUint40(b, num)
Expect(b.Bytes()).To(Equal([]byte{0x05, 0x04, 0x03, 0x02, 0x01}))
})
It("panics if the value doesn't fit into 40 bits", func() {
num := uint64(0x010203040506)
b := &bytes.Buffer{}
Expect(func() { LittleEndian.WriteUint40(b, num) }).Should(Panic())
})
})
Context("WriteUint48", func() {
It("outputs 6 bytes", func() {
b := &bytes.Buffer{}
LittleEndian.WriteUint48(b, uint64(1))
Expect(b.Len()).To(Equal(6))
})
It("outputs a little endian", func() {
num := uint64(0xDEADBEEFCAFE)
b := &bytes.Buffer{}
LittleEndian.WriteUint48(b, num)
Expect(b.Bytes()).To(Equal([]byte{0xFE, 0xCA, 0xEF, 0xBE, 0xAD, 0xDE}))
})
It("panics if the value doesn't fit into 48 bits", func() {
num := uint64(0xDEADBEEFCAFE01)
b := &bytes.Buffer{}
Expect(func() { LittleEndian.WriteUint48(b, num) }).Should(Panic())
})
})
Context("WriteUint56", func() {
It("outputs 7 bytes", func() {
b := &bytes.Buffer{}
LittleEndian.WriteUint56(b, uint64(1))
Expect(b.Len()).To(Equal(7))
})
It("outputs a little endian", func() {
num := uint64(0xEEDDCCBBAA9988)
b := &bytes.Buffer{}
LittleEndian.WriteUint56(b, num)
Expect(b.Bytes()).To(Equal([]byte{0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE}))
})
It("panics if the value doesn't fit into 56 bits", func() {
num := uint64(0xEEDDCCBBAA998801)
b := &bytes.Buffer{}
Expect(func() { LittleEndian.WriteUint56(b, num) }).Should(Panic())
})
})
Context("WriteUint64", func() {
It("outputs 8 bytes", func() {
b := &bytes.Buffer{}
LittleEndian.WriteUint64(b, uint64(1))
Expect(b.Len()).To(Equal(8))
})
It("outputs a little endian", func() {
num := uint64(0xFFEEDDCCBBAA9988)
b := &bytes.Buffer{}
LittleEndian.WriteUint64(b, num)
Expect(b.Bytes()).To(Equal([]byte{0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}))
})
})
Context("ReadUintN", func() {
It("reads n bytes", func() {
m := map[uint8]uint64{
0: 0x0, 1: 0x01, 2: 0x0201, 3: 0x030201, 4: 0x04030201, 5: 0x0504030201,
6: 0x060504030201, 7: 0x07060504030201, 8: 0x0807060504030201,
}
for n, expected := range m {
b := bytes.NewReader([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})
i, err := LittleEndian.ReadUintN(b, n)
Expect(err).ToNot(HaveOccurred())
Expect(i).To(Equal(expected))
}
})
It("errors", func() {
b := bytes.NewReader([]byte{0x1, 0x2})
_, err := LittleEndian.ReadUintN(b, 3)
Expect(err).To(HaveOccurred())
})
})
})

Some files were not shown because too many files have changed in this diff Show more