mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
drop support for gQUIC
This commit is contained in:
parent
8f8ed03254
commit
3266e36811
195 changed files with 2638 additions and 35430 deletions
|
@ -14,7 +14,6 @@ defaults: &defaults
|
|||
command: |
|
||||
echo $GOARCH
|
||||
go version
|
||||
google-chrome --version
|
||||
printf "quic.clemente.io certificate valid until: " && openssl x509 -in example/fullchain.pem -enddate -noout | cut -d = -f 2
|
||||
- run:
|
||||
name: "Run benchmark tests"
|
||||
|
@ -25,12 +24,6 @@ defaults: &defaults
|
|||
- run:
|
||||
name: "Run tools tests"
|
||||
command: ginkgo -r -v -randomizeAllSpecs -trace integrationtests/tools
|
||||
- run:
|
||||
name: "Run Chrome integration tests"
|
||||
command: ginkgo -v -randomizeAllSpecs -trace integrationtests/chrome
|
||||
- run:
|
||||
name: "Run integration tests using the gQUIC toy server / client"
|
||||
command: ginkgo -v -randomizeAllSpecs -trace integrationtests/gquic
|
||||
- run:
|
||||
name: "Run self integration tests"
|
||||
command: ginkgo -v -randomizeAllSpecs -trace integrationtests/self
|
||||
|
|
|
@ -22,5 +22,5 @@ if [ ${TESTMODE} == "integration" ]; then
|
|||
ginkgo -race -randomizeAllSpecs -randomizeSuites -trace benchmark -- -samples=1 -size=10
|
||||
fi
|
||||
# run integration tests
|
||||
ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage chrome integrationtests
|
||||
ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage integrationtests
|
||||
fi
|
||||
|
|
33
README.md
33
README.md
|
@ -8,12 +8,19 @@
|
|||
[](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
|
||||
[](https://codecov.io/gh/lucas-clemente/quic-go/)
|
||||
|
||||
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
|
||||
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go. It roughly implements the [IETF QUIC draft](https://github.com/quicwg/base-drafts), although we don't fully support any of the draft versions at the moment.
|
||||
|
||||
## Roadmap
|
||||
## Version compatibility
|
||||
|
||||
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
|
||||
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
|
||||
Since quic-go is under active development, there's no guarantee that two builds of different commits are interoperable. The QUIC version used in the *master* branch is just a placeholder, and should not be considered stable.
|
||||
|
||||
If you want to use quic-go as a library in other projects, please consider using a [tagged release](https://github.com/lucas-clemente/quic-go/releases). These releases expose [experimental QUIC versions](https://github.com/quicwg/base-drafts/wiki/QUIC-Versions), which are guaranteed to be stable.
|
||||
|
||||
## Google QUIC
|
||||
|
||||
quic-go used to support both the QUIC versions supported by Google Chrome and QUIC as deployed on Google's servers, as well as IETF QUIC. Due to the divergence of the two protocols, we decided to not support both versions any more.
|
||||
|
||||
The *master* branch **only** supports IETF QUIC. For Google QUIC support, please refer to the [gquic branch](https://github.com/lucas-clemente/quic-go/tree/gquic).
|
||||
|
||||
## Guides
|
||||
|
||||
|
@ -27,31 +34,19 @@ Running tests:
|
|||
|
||||
go test ./...
|
||||
|
||||
### Running the example server
|
||||
### HTTP mapping
|
||||
|
||||
go run example/main.go -www /var/www/
|
||||
|
||||
Using the `quic_client` from chromium:
|
||||
|
||||
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
|
||||
|
||||
Using Chrome:
|
||||
|
||||
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
|
||||
We're currently not implementing the HTTP mapping as described in the [QUIC over HTTP draft](https://quicwg.org/base-drafts/draft-ietf-quic-http.html). The HTTP mapping here is a leftover from Google QUIC.
|
||||
|
||||
### QUIC without HTTP/2
|
||||
|
||||
Take a look at [this echo example](example/echo/echo.go).
|
||||
|
||||
### Using the example client
|
||||
|
||||
go run example/client/main.go https://clemente.io
|
||||
|
||||
## Usage
|
||||
|
||||
### As a server
|
||||
|
||||
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
|
||||
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
|
||||
|
||||
```go
|
||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||
|
|
127
client.go
127
client.go
|
@ -1,7 +1,6 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
|
@ -123,13 +122,6 @@ func dialContext(
|
|||
createdPacketConn bool,
|
||||
) (Session, error) {
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
if !createdPacketConn {
|
||||
for _, v := range config.Versions {
|
||||
if v == protocol.Version44 {
|
||||
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
|
||||
}
|
||||
}
|
||||
}
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -234,17 +226,11 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
|||
if connIDLen == 0 && !createdPacketConn {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
for _, v := range versions {
|
||||
if v == protocol.Version44 {
|
||||
connIDLen = 0
|
||||
}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
|
||||
ConnectionIDLength: connIDLen,
|
||||
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
|
||||
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
|
||||
|
@ -255,53 +241,22 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
|||
}
|
||||
|
||||
func (c *client) generateConnectionIDs() error {
|
||||
connIDLen := protocol.ConnectionIDLenGQUIC
|
||||
if c.version.UsesTLS() {
|
||||
connIDLen = c.config.ConnectionIDLength
|
||||
}
|
||||
srcConnID, err := generateConnectionID(connIDLen)
|
||||
srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID := srcConnID
|
||||
if c.version.UsesTLS() {
|
||||
destConnID, err = generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.srcConnID = srcConnID
|
||||
c.destConnID = destConnID
|
||||
if c.version == protocol.Version44 {
|
||||
c.srcConnID = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
var err error
|
||||
if c.version.UsesTLS() {
|
||||
err = c.dialTLS(ctx)
|
||||
} else {
|
||||
err = c.dialGQUIC(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialGQUIC(ctx context.Context) error {
|
||||
if err := c.createNewGQUICSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
err := c.establishSecureConnection(ctx)
|
||||
if err == errCloseSessionForNewVersion {
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) dialTLS(ctx context.Context) error {
|
||||
if err := c.createNewTLSSession(c.version); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -315,9 +270,9 @@ func (c *client) dialTLS(ctx context.Context) error {
|
|||
// establishSecureConnection runs the session, and tries to establish a secure connection
|
||||
// It returns:
|
||||
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
|
||||
// - any other error that might occur
|
||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||
// - when the connection is forward-secure
|
||||
func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
|
@ -362,24 +317,9 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if !c.version.UsesIETFHeaderFormat() {
|
||||
connID := p.header.DestConnectionID
|
||||
// reject packets with truncated connection id if we didn't request truncation
|
||||
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
|
||||
return errors.New("received packet with truncated connection ID, but didn't request truncation")
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
|
||||
}
|
||||
if p.header.ResetFlag {
|
||||
return c.handlePublicReset(p)
|
||||
}
|
||||
} else {
|
||||
// reject packets with the wrong connection ID
|
||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
// reject packets with the wrong connection ID
|
||||
if !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
|
||||
}
|
||||
|
||||
if p.header.Type == protocol.PacketTypeRetry {
|
||||
|
@ -397,22 +337,6 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handlePublicReset(p *receivedPacket) error {
|
||||
cr := c.conn.RemoteAddr()
|
||||
// check if the remote address and the connection ID match
|
||||
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
|
||||
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
|
||||
return errors.New("Received a spoofed Public Reset")
|
||||
}
|
||||
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
|
||||
}
|
||||
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
|
||||
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
|
||||
// ignore delayed / duplicated version negotiation packets
|
||||
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
|
||||
|
@ -471,42 +395,11 @@ func (c *client) handleRetryPacket(hdr *wire.Header) {
|
|||
c.session.destroy(errCloseSessionForRetry)
|
||||
}
|
||||
|
||||
func (c *client) createNewGQUICSession() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
sess, err := newClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.version,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.tlsConf,
|
||||
c.config,
|
||||
c.initialVersion,
|
||||
c.negotiatedVersions,
|
||||
c.logger,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.session = sess
|
||||
c.packetHandlers.Add(c.srcConnID, c)
|
||||
if c.config.RequestConnectionIDOmission {
|
||||
c.packetHandlers.Add(protocol.ConnectionID{}, c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
|
||||
params := &handshake.TransportParameters{
|
||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||
IdleTimeout: c.config.IdleTimeout,
|
||||
OmitConnectionID: c.config.RequestConnectionIDOmission,
|
||||
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
|
||||
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
|
||||
DisableMigration: true,
|
||||
|
@ -518,7 +411,7 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
|
|||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
sess, err := newTLSClientSession(
|
||||
sess, err := newClientSession(
|
||||
c.conn,
|
||||
runner,
|
||||
c.token,
|
||||
|
|
694
client_test.go
694
client_test.go
|
@ -30,9 +30,20 @@ var _ = Describe("Client", func() {
|
|||
mockMultiplexer *MockMultiplexer
|
||||
origMultiplexer multiplexer
|
||||
|
||||
supportedVersionsWithoutGQUIC44 []protocol.VersionNumber
|
||||
|
||||
originalClientSessConstructor func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *tls.Config, *Config, protocol.VersionNumber, []protocol.VersionNumber, utils.Logger) (quicSession, error)
|
||||
originalClientSessConstructor func(
|
||||
conn connection,
|
||||
runner sessionRunner,
|
||||
token []byte,
|
||||
destConnID protocol.ConnectionID,
|
||||
srcConnID protocol.ConnectionID,
|
||||
conf *Config,
|
||||
tlsConf *tls.Config,
|
||||
params *handshake.TransportParameters,
|
||||
initialVersion protocol.VersionNumber,
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
) (quicSession, error)
|
||||
)
|
||||
|
||||
// generate a packet sent by the server that accepts the QUIC version suggested by the client
|
||||
|
@ -79,12 +90,6 @@ var _ = Describe("Client", func() {
|
|||
mockMultiplexer = NewMockMultiplexer(mockCtrl)
|
||||
origMultiplexer = connMuxer
|
||||
connMuxer = mockMultiplexer
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
if v != protocol.Version44 {
|
||||
supportedVersionsWithoutGQUIC44 = append(supportedVersionsWithoutGQUIC44, v)
|
||||
}
|
||||
}
|
||||
Expect(supportedVersionsWithoutGQUIC44).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
@ -131,14 +136,16 @@ var _ = Describe("Client", func() {
|
|||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
remoteAddrChan <- conn.RemoteAddr().String()
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -159,14 +166,16 @@ var _ = Describe("Client", func() {
|
|||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
tlsConf *tls.Config,
|
||||
_ *Config,
|
||||
tlsConf *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
hostnameChan <- tlsConf.ServerName
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -187,14 +196,16 @@ var _ = Describe("Client", func() {
|
|||
newClientSession = func(
|
||||
_ connection,
|
||||
runner sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().Do(func() { close(run) })
|
||||
|
@ -206,25 +217,13 @@ var _ = Describe("Client", func() {
|
|||
addr,
|
||||
"quic.clemente.io:1337",
|
||||
nil,
|
||||
&Config{Versions: supportedVersionsWithoutGQUIC44},
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(s).ToNot(BeNil())
|
||||
Eventually(run).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("refuses to multiplex gQUIC 44", func() {
|
||||
_, err := Dial(
|
||||
packetConn,
|
||||
addr,
|
||||
"quic.clemente.io:1337",
|
||||
nil,
|
||||
&Config{Versions: []protocol.VersionNumber{protocol.Version44}},
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("Cannot multiplex connections using gQUIC 44"))
|
||||
})
|
||||
|
||||
It("returns an error that occurs while waiting for the connection to become secure", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
|
@ -232,16 +231,18 @@ var _ = Describe("Client", func() {
|
|||
|
||||
testErr := errors.New("early handshake error")
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().Return(testErr)
|
||||
|
@ -253,7 +254,7 @@ var _ = Describe("Client", func() {
|
|||
addr,
|
||||
"quic.clemente.io:1337",
|
||||
nil,
|
||||
&Config{Versions: supportedVersionsWithoutGQUIC44},
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
@ -270,16 +271,18 @@ var _ = Describe("Client", func() {
|
|||
<-sessionRunning
|
||||
})
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
return sess, nil
|
||||
}
|
||||
|
@ -293,7 +296,7 @@ var _ = Describe("Client", func() {
|
|||
addr,
|
||||
"quic.clemnte.io:1337",
|
||||
nil,
|
||||
&Config{Versions: supportedVersionsWithoutGQUIC44},
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).To(MatchError(context.Canceled))
|
||||
close(dialed)
|
||||
|
@ -313,16 +316,18 @@ var _ = Describe("Client", func() {
|
|||
var runner sessionRunner
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ connection,
|
||||
runnerP sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
runner = runnerP
|
||||
return sess, nil
|
||||
|
@ -337,7 +342,7 @@ var _ = Describe("Client", func() {
|
|||
addr,
|
||||
"quic.clemnte.io:1337",
|
||||
nil,
|
||||
&Config{Versions: supportedVersionsWithoutGQUIC44},
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
@ -354,14 +359,16 @@ var _ = Describe("Client", func() {
|
|||
newClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
conn = connP
|
||||
close(sessionCreated)
|
||||
|
@ -397,39 +404,20 @@ var _ = Describe("Client", func() {
|
|||
Context("quic.Config", func() {
|
||||
It("setups with the right values", func() {
|
||||
config := &Config{
|
||||
HandshakeTimeout: 1337 * time.Minute,
|
||||
IdleTimeout: 42 * time.Hour,
|
||||
RequestConnectionIDOmission: true,
|
||||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
ConnectionIDLength: 13,
|
||||
Versions: supportedVersionsWithoutGQUIC44,
|
||||
HandshakeTimeout: 1337 * time.Minute,
|
||||
IdleTimeout: 42 * time.Hour,
|
||||
MaxIncomingStreams: 1234,
|
||||
MaxIncomingUniStreams: 4321,
|
||||
ConnectionIDLength: 13,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute))
|
||||
Expect(c.IdleTimeout).To(Equal(42 * time.Hour))
|
||||
Expect(c.RequestConnectionIDOmission).To(BeTrue())
|
||||
Expect(c.MaxIncomingStreams).To(Equal(1234))
|
||||
Expect(c.MaxIncomingUniStreams).To(Equal(4321))
|
||||
Expect(c.ConnectionIDLength).To(Equal(13))
|
||||
})
|
||||
|
||||
It("uses a 0 byte connection IDs if gQUIC 44 is supported", func() {
|
||||
config := &Config{
|
||||
Versions: []protocol.VersionNumber{protocol.Version43, protocol.Version44},
|
||||
ConnectionIDLength: 13,
|
||||
}
|
||||
c := populateClientConfig(config, false)
|
||||
Expect(c.Versions).To(Equal([]protocol.VersionNumber{protocol.Version43, protocol.Version44}))
|
||||
Expect(c.ConnectionIDLength).To(BeZero())
|
||||
})
|
||||
|
||||
It("doesn't use 0-byte connection IDs when dialing an address", func() {
|
||||
config := &Config{Versions: supportedVersionsWithoutGQUIC44}
|
||||
c := populateClientConfig(config, false)
|
||||
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
|
||||
})
|
||||
|
||||
It("errors when the Config contains an invalid version", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
@ -470,194 +458,160 @@ var _ = Describe("Client", func() {
|
|||
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
|
||||
Expect(c.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout))
|
||||
Expect(c.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
|
||||
Expect(c.RequestConnectionIDOmission).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("gQUIC", func() {
|
||||
It("errors if it can't create a session", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
It("creates new TLS sessions with the right parameters", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
||||
testErr := errors.New("error creating session")
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ utils.Logger,
|
||||
) (quicSession, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err := Dial(
|
||||
packetConn,
|
||||
addr,
|
||||
"quic.clemente.io:1337",
|
||||
nil,
|
||||
&Config{Versions: supportedVersionsWithoutGQUIC44},
|
||||
)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IETF QUIC", func() {
|
||||
It("creates new TLS sessions with the right parameters", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
c := make(chan struct{})
|
||||
var cconn connection
|
||||
var version protocol.VersionNumber
|
||||
var conf *Config
|
||||
newTLSClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
tokenP []byte,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
params *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber, /* initial version */
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
cconn = connP
|
||||
version = versionP
|
||||
conf = configP
|
||||
close(c)
|
||||
// TODO: check connection IDs?
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run()
|
||||
return sess, nil
|
||||
}
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
||||
It("creates a new session when the server performs a retry", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
|
||||
go handler.handlePacket(&receivedPacket{
|
||||
header: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Token: []byte("foobar"),
|
||||
DestConnectionID: id,
|
||||
OrigDestConnectionID: connID,
|
||||
},
|
||||
})
|
||||
})
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
cl.config = config
|
||||
run1 := make(chan error)
|
||||
sess1 := NewMockQuicSession(mockCtrl)
|
||||
sess1.EXPECT().run().DoAndReturn(func() error {
|
||||
return <-run1
|
||||
})
|
||||
sess1.EXPECT().destroy(errCloseSessionForRetry).Do(func(e error) {
|
||||
run1 <- e
|
||||
})
|
||||
sess2 := NewMockQuicSession(mockCtrl)
|
||||
sess2.EXPECT().run()
|
||||
sessions := make(chan quicSession, 2)
|
||||
sessions <- sess1
|
||||
sessions <- sess2
|
||||
newTLSClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ []byte,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber, /* initial version */
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
return <-sessions, nil
|
||||
}
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sessions).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("only accepts a single retry", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
|
||||
go handler.handlePacket(&receivedPacket{
|
||||
header: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Token: []byte("foobar"),
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
DestConnectionID: id,
|
||||
OrigDestConnectionID: connID,
|
||||
Version: protocol.VersionTLS,
|
||||
},
|
||||
})
|
||||
}).AnyTimes()
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
cl.config = config
|
||||
|
||||
sessions := make(chan quicSession, 2)
|
||||
run := make(chan error)
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
c := make(chan struct{})
|
||||
var cconn connection
|
||||
var version protocol.VersionNumber
|
||||
var conf *Config
|
||||
newClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
tokenP []byte,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
params *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber, /* initial version */
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
cconn = connP
|
||||
version = versionP
|
||||
conf = configP
|
||||
close(c)
|
||||
// TODO: check connection IDs?
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().DoAndReturn(func() error {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
Eventually(run).Should(Receive(&err))
|
||||
return err
|
||||
})
|
||||
sess.EXPECT().destroy(gomock.Any()).Do(func(e error) {
|
||||
run <- e
|
||||
})
|
||||
sessions <- sess
|
||||
doneErr := errors.New("nothing to do")
|
||||
sess = NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().Return(doneErr)
|
||||
sessions <- sess
|
||||
sess.EXPECT().run()
|
||||
return sess, nil
|
||||
}
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
||||
newTLSClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ []byte,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber, /* initial version */
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
return <-sessions, nil
|
||||
}
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(doneErr))
|
||||
Expect(sessions).To(BeEmpty())
|
||||
It("creates a new session when the server performs a retry", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
|
||||
go handler.handlePacket(&receivedPacket{
|
||||
header: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Token: []byte("foobar"),
|
||||
DestConnectionID: id,
|
||||
OrigDestConnectionID: connID,
|
||||
},
|
||||
})
|
||||
})
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
cl.config = config
|
||||
run1 := make(chan error)
|
||||
sess1 := NewMockQuicSession(mockCtrl)
|
||||
sess1.EXPECT().run().DoAndReturn(func() error {
|
||||
return <-run1
|
||||
})
|
||||
sess1.EXPECT().destroy(errCloseSessionForRetry).Do(func(e error) {
|
||||
run1 <- e
|
||||
})
|
||||
sess2 := NewMockQuicSession(mockCtrl)
|
||||
sess2.EXPECT().run()
|
||||
sessions := make(chan quicSession, 2)
|
||||
sessions <- sess1
|
||||
sessions <- sess2
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
return <-sessions, nil
|
||||
}
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sessions).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("only accepts a single retry", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any()).Do(func(id protocol.ConnectionID, handler packetHandler) {
|
||||
go handler.handlePacket(&receivedPacket{
|
||||
header: &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeRetry,
|
||||
Token: []byte("foobar"),
|
||||
SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4},
|
||||
DestConnectionID: id,
|
||||
OrigDestConnectionID: connID,
|
||||
Version: protocol.VersionTLS,
|
||||
},
|
||||
})
|
||||
}).AnyTimes()
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
cl.config = config
|
||||
|
||||
sessions := make(chan quicSession, 2)
|
||||
run := make(chan error)
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().DoAndReturn(func() error {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
Eventually(run).Should(Receive(&err))
|
||||
return err
|
||||
})
|
||||
sess.EXPECT().destroy(gomock.Any()).Do(func(e error) {
|
||||
run <- e
|
||||
})
|
||||
sessions <- sess
|
||||
doneErr := errors.New("nothing to do")
|
||||
sess = NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().Return(doneErr)
|
||||
sessions <- sess
|
||||
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
return <-sessions, nil
|
||||
}
|
||||
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
|
||||
Expect(err).To(MatchError(doneErr))
|
||||
Expect(sessions).To(BeEmpty())
|
||||
})
|
||||
|
||||
Context("version negotiation", func() {
|
||||
|
@ -681,14 +635,16 @@ var _ = Describe("Client", func() {
|
|||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ []byte, // token
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ *handshake.TransportParameters,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) (quicSession, error) {
|
||||
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
|
@ -700,7 +656,7 @@ var _ = Describe("Client", func() {
|
|||
addr,
|
||||
"quic.clemente.io:1337",
|
||||
nil,
|
||||
&Config{Versions: supportedVersionsWithoutGQUIC44},
|
||||
&Config{},
|
||||
)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
@ -721,103 +677,6 @@ var _ = Describe("Client", func() {
|
|||
Expect(cl.versionNegotiated).To(BeTrue())
|
||||
})
|
||||
|
||||
It("changes the version after receiving a Version Negotiation Packet", func() {
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
phm.EXPECT().Add(connID, gomock.Any()).Times(2)
|
||||
cl.packetHandlers = phm
|
||||
|
||||
version1 := protocol.Version39
|
||||
version2 := protocol.Version39 + 1
|
||||
Expect(version2.UsesTLS()).To(BeFalse())
|
||||
sess1 := NewMockQuicSession(mockCtrl)
|
||||
run1 := make(chan struct{})
|
||||
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
|
||||
sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
|
||||
sess2 := NewMockQuicSession(mockCtrl)
|
||||
sess2.EXPECT().run()
|
||||
sessionChan := make(chan *MockQuicSession, 2)
|
||||
sessionChan <- sess1
|
||||
sessionChan <- sess2
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ utils.Logger,
|
||||
) (quicSession, error) {
|
||||
return <-sessionChan, nil
|
||||
}
|
||||
|
||||
cl.tlsConf = &tls.Config{}
|
||||
cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2}}
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cl.dial(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(dialed)
|
||||
}()
|
||||
Eventually(sessionChan).Should(HaveLen(1))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2}))
|
||||
Eventually(sessionChan).Should(BeEmpty())
|
||||
})
|
||||
|
||||
It("only accepts one version negotiation packet", func() {
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
phm.EXPECT().Add(connID, gomock.Any()).Times(2)
|
||||
cl.packetHandlers = phm
|
||||
version1 := protocol.Version39
|
||||
version2 := protocol.Version39 + 1
|
||||
version3 := protocol.Version39 + 2
|
||||
Expect(version2.UsesTLS()).To(BeFalse())
|
||||
Expect(version3.UsesTLS()).To(BeFalse())
|
||||
sess1 := NewMockQuicSession(mockCtrl)
|
||||
run1 := make(chan struct{})
|
||||
sess1.EXPECT().run().Do(func() { <-run1 }).Return(errCloseSessionForNewVersion)
|
||||
sess1.EXPECT().destroy(errCloseSessionForNewVersion).Do(func(error) { close(run1) })
|
||||
sess2 := NewMockQuicSession(mockCtrl)
|
||||
sess2.EXPECT().run()
|
||||
sessionChan := make(chan *MockQuicSession, 2)
|
||||
sessionChan <- sess1
|
||||
sessionChan <- sess2
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.VersionNumber,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *tls.Config,
|
||||
_ *Config,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ utils.Logger,
|
||||
) (quicSession, error) {
|
||||
return <-sessionChan, nil
|
||||
}
|
||||
|
||||
cl.tlsConf = &tls.Config{}
|
||||
cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2, version3}}
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
err := cl.dial(context.Background())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(dialed)
|
||||
}()
|
||||
Eventually(sessionChan).Should(HaveLen(1))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version2}))
|
||||
Eventually(sessionChan).Should(BeEmpty())
|
||||
Expect(cl.version).To(Equal(version2))
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{version3}))
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(version2))
|
||||
})
|
||||
|
||||
It("errors if no matching version is found", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().destroy(qerr.InvalidVersion)
|
||||
|
@ -863,26 +722,8 @@ var _ = Describe("Client", func() {
|
|||
Expect(cl.GetVersion()).To(Equal(cl.version))
|
||||
})
|
||||
|
||||
It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
|
||||
cl.version = versionGQUICFrames
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
cl.config = &Config{RequestConnectionIDOmission: false}
|
||||
hdr := &wire.Header{
|
||||
IsPublicHeader: true,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: 1,
|
||||
}
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
})
|
||||
Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation"))
|
||||
})
|
||||
|
||||
It("ignores packets with the wrong destination connection ID", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
cl.version = versionIETFFrames
|
||||
cl.config = &Config{RequestConnectionIDOmission: false}
|
||||
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
Expect(connID).ToNot(Equal(connID2))
|
||||
hdr := &wire.Header{
|
||||
|
@ -890,7 +731,6 @@ var _ = Describe("Client", func() {
|
|||
SrcConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
|
@ -898,110 +738,4 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID)))
|
||||
})
|
||||
|
||||
It("creates new gQUIC sessions with the right parameters", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(gomock.Any(), gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil)
|
||||
|
||||
c := make(chan struct{})
|
||||
var cconn connection
|
||||
var hostname string
|
||||
var version protocol.VersionNumber
|
||||
var conf *Config
|
||||
newClientSession = func(
|
||||
connP connection,
|
||||
_ sessionRunner,
|
||||
versionP protocol.VersionNumber,
|
||||
connIDP protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
tlsConf *tls.Config,
|
||||
configP *Config,
|
||||
_ protocol.VersionNumber,
|
||||
_ []protocol.VersionNumber,
|
||||
_ utils.Logger,
|
||||
) (quicSession, error) {
|
||||
cconn = connP
|
||||
hostname = tlsConf.ServerName
|
||||
version = versionP
|
||||
conf = configP
|
||||
connID = connIDP
|
||||
close(c)
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run()
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
config := &Config{Versions: supportedVersionsWithoutGQUIC44}
|
||||
_, err := Dial(
|
||||
packetConn,
|
||||
addr,
|
||||
"quic.clemente.io:1337",
|
||||
nil,
|
||||
config,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Eventually(c).Should(BeClosed())
|
||||
Expect(cconn.(*conn).pconn).To(Equal(packetConn))
|
||||
Expect(hostname).To(Equal("quic.clemente.io"))
|
||||
Expect(version).To(Equal(config.Versions[0]))
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
||||
Context("Public Reset handling", func() {
|
||||
var (
|
||||
pr []byte
|
||||
hdr *wire.Header
|
||||
hdrLen int
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
cl.config = &Config{}
|
||||
|
||||
pr = wire.WritePublicReset(cl.destConnID, 1, 0)
|
||||
r := bytes.NewReader(pr)
|
||||
iHdr, err := wire.ParseInvariantHeader(r, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdr, err = iHdr.Parse(r, protocol.PerspectiveServer, versionGQUICFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
hdrLen = r.Len()
|
||||
})
|
||||
|
||||
It("closes the session when receiving a Public Reset", func() {
|
||||
cl.version = versionGQUICFrames
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().closeRemote(gomock.Any()).Do(func(err error) {
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset))
|
||||
})
|
||||
cl.session = sess
|
||||
cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: pr[len(pr)-hdrLen:],
|
||||
})
|
||||
})
|
||||
|
||||
It("ignores Public Resets from the wrong remote address", func() {
|
||||
cl.version = versionGQUICFrames
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
|
||||
spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678}
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: spoofedAddr,
|
||||
header: hdr,
|
||||
data: pr[len(pr)-hdrLen:],
|
||||
})
|
||||
Expect(err).To(MatchError("Received a spoofed Public Reset"))
|
||||
})
|
||||
|
||||
It("ignores unparseable Public Resets", func() {
|
||||
cl.version = versionGQUICFrames
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: pr[len(pr)-hdrLen : len(pr)-5], // cut off the last 5 bytes
|
||||
})
|
||||
Expect(err.Error()).To(ContainSubstring("Received a Public Reset. An error occurred parsing the packet"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -7,15 +7,12 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
func main() {
|
||||
verbose := flag.Bool("v", false, "verbose")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
quiet := flag.Bool("q", false, "don't print the data")
|
||||
flag.Parse()
|
||||
urls := flag.Args()
|
||||
|
@ -29,14 +26,7 @@ func main() {
|
|||
}
|
||||
logger.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
roundTripper := &h2quic.RoundTripper{
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
}
|
||||
roundTripper := &h2quic.RoundTripper{}
|
||||
defer roundTripper.Close()
|
||||
hclient := &http.Client{
|
||||
Transport: roundTripper,
|
||||
|
|
|
@ -17,9 +17,7 @@ import (
|
|||
|
||||
_ "net/http/pprof"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -123,7 +121,6 @@ func main() {
|
|||
certPath := flag.String("certpath", getBuildDir(), "certificate directory")
|
||||
www := flag.String("www", "/var/www", "www data")
|
||||
tcp := flag.Bool("tcp", false, "also listen on TCP")
|
||||
tls := flag.Bool("tls", false, "activate support for IETF QUIC (work in progress)")
|
||||
flag.Parse()
|
||||
|
||||
logger := utils.DefaultLogger
|
||||
|
@ -135,11 +132,6 @@ func main() {
|
|||
}
|
||||
logger.SetLogTimeFormat("")
|
||||
|
||||
versions := protocol.SupportedVersions
|
||||
if *tls {
|
||||
versions = append([]protocol.VersionNumber{protocol.VersionTLS}, versions...)
|
||||
}
|
||||
|
||||
certFile := *certPath + "/fullchain.pem"
|
||||
keyFile := *certPath + "/privkey.pem"
|
||||
|
||||
|
@ -159,8 +151,7 @@ func main() {
|
|||
err = h2quic.ListenAndServe(bCap, certFile, keyFile, nil)
|
||||
} else {
|
||||
server := h2quic.Server{
|
||||
Server: &http.Server{Addr: bCap},
|
||||
QuicConfig: &quic.Config{Versions: versions},
|
||||
Server: &http.Server{Addr: bCap},
|
||||
}
|
||||
err = server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
|
|
|
@ -52,10 +52,7 @@ type client struct {
|
|||
|
||||
var _ http.RoundTripper = &client{}
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
RequestConnectionIDOmission: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
var defaultQuicConfig = &quic.Config{KeepAlive: true}
|
||||
|
||||
// newClient creates a new client
|
||||
func newClient(
|
||||
|
|
|
@ -1,210 +0,0 @@
|
|||
package chrome_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gexec"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
dataLen = 500 * 1024 // 500 KB
|
||||
dataLongLen = 50 * 1024 * 1024 // 50 MB
|
||||
)
|
||||
|
||||
var (
|
||||
nFilesUploaded int32 // should be used atomically
|
||||
doneCalled utils.AtomicBool
|
||||
)
|
||||
|
||||
func TestChrome(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Chrome Suite")
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Requires the len & num GET parameters, e.g. /uploadtest?len=100&num=1
|
||||
http.HandleFunc("/uploadtest", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
response := uploadHTML
|
||||
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
|
||||
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
|
||||
_, err := io.WriteString(w, response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
// Requires the len & num GET parameters, e.g. /downloadtest?len=100&num=1
|
||||
http.HandleFunc("/downloadtest", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
response := downloadHTML
|
||||
response = strings.Replace(response, "LENGTH", r.URL.Query().Get("len"), -1)
|
||||
response = strings.Replace(response, "NUM", r.URL.Query().Get("num"), -1)
|
||||
_, err := io.WriteString(w, response)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
http.HandleFunc("/uploadhandler", func(w http.ResponseWriter, r *http.Request) {
|
||||
defer GinkgoRecover()
|
||||
|
||||
l, err := strconv.Atoi(r.URL.Query().Get("len"))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
defer r.Body.Close()
|
||||
actual, err := ioutil.ReadAll(r.Body)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(bytes.Equal(actual, testserver.GeneratePRData(l))).To(BeTrue())
|
||||
|
||||
atomic.AddInt32(&nFilesUploaded, 1)
|
||||
})
|
||||
|
||||
http.HandleFunc("/done", func(w http.ResponseWriter, r *http.Request) {
|
||||
doneCalled.Set(true)
|
||||
})
|
||||
}
|
||||
|
||||
var _ = AfterEach(func() {
|
||||
testserver.StopQuicServer()
|
||||
|
||||
atomic.StoreInt32(&nFilesUploaded, 0)
|
||||
doneCalled.Set(false)
|
||||
})
|
||||
|
||||
func getChromePath() string {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
|
||||
}
|
||||
if path, err := exec.LookPath("google-chrome"); err == nil {
|
||||
return path
|
||||
}
|
||||
if path, err := exec.LookPath("chromium-browser"); err == nil {
|
||||
return path
|
||||
}
|
||||
Fail("No Chrome executable found.")
|
||||
return ""
|
||||
}
|
||||
|
||||
func chromeTest(version protocol.VersionNumber, url string, blockUntilDone func()) {
|
||||
userDataDir, err := ioutil.TempDir("", "quic-go-test-chrome-dir")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer os.RemoveAll(userDataDir)
|
||||
path := getChromePath()
|
||||
args := []string{
|
||||
"--disable-gpu",
|
||||
"--no-first-run=true",
|
||||
"--no-default-browser-check=true",
|
||||
"--user-data-dir=" + userDataDir,
|
||||
"--enable-quic=true",
|
||||
"--no-proxy-server=true",
|
||||
"--no-sandbox",
|
||||
"--origin-to-force-quic-on=quic.clemente.io:443",
|
||||
fmt.Sprintf(`--host-resolver-rules=MAP quic.clemente.io:443 127.0.0.1:%s`, testserver.Port()),
|
||||
fmt.Sprintf("--quic-version=QUIC_VERSION_%s", version.ToAltSvc()),
|
||||
url,
|
||||
}
|
||||
utils.DefaultLogger.Infof("Running chrome: %s '%s'", getChromePath(), strings.Join(args, "' '"))
|
||||
command := exec.Command(path, args...)
|
||||
session, err := gexec.Start(command, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
blockUntilDone()
|
||||
}
|
||||
|
||||
func waitForDone() {
|
||||
Eventually(func() bool { return doneCalled.Get() }, 60).Should(BeTrue())
|
||||
}
|
||||
|
||||
func waitForNUploaded(expected int) func() {
|
||||
return func() {
|
||||
Eventually(func() int32 {
|
||||
return atomic.LoadInt32(&nFilesUploaded)
|
||||
}, 60).Should(BeEquivalentTo(expected))
|
||||
}
|
||||
}
|
||||
|
||||
const commonJS = `
|
||||
var buf = new ArrayBuffer(LENGTH);
|
||||
var prng = new Uint8Array(buf);
|
||||
var seed = 1;
|
||||
for (var i = 0; i < LENGTH; i++) {
|
||||
// https://en.wikipedia.org/wiki/Lehmer_random_number_generator
|
||||
seed = seed * 48271 % 2147483647;
|
||||
prng[i] = seed;
|
||||
}
|
||||
`
|
||||
|
||||
const uploadHTML = `
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
console.log("Running DL test...");
|
||||
|
||||
` + commonJS + `
|
||||
for (var i = 0; i < NUM; i++) {
|
||||
var req = new XMLHttpRequest();
|
||||
req.open("POST", "/uploadhandler?len=" + LENGTH, true);
|
||||
req.send(buf);
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
const downloadHTML = `
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
console.log("Running DL test...");
|
||||
` + commonJS + `
|
||||
|
||||
function verify(data) {
|
||||
if (data.length !== LENGTH) return false;
|
||||
for (var i = 0; i < LENGTH; i++) {
|
||||
if (data[i] !== prng[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
var nOK = 0;
|
||||
for (var i = 0; i < NUM; i++) {
|
||||
let req = new XMLHttpRequest();
|
||||
req.responseType = "arraybuffer";
|
||||
req.open("POST", "/prdata?len=" + LENGTH, true);
|
||||
req.onreadystatechange = function () {
|
||||
if (req.readyState === XMLHttpRequest.DONE && req.status === 200) {
|
||||
if (verify(new Uint8Array(req.response))) {
|
||||
nOK++;
|
||||
if (nOK === NUM) {
|
||||
console.log("Done :)");
|
||||
var reqDone = new XMLHttpRequest();
|
||||
reqDone.open("GET", "/done");
|
||||
reqDone.send();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
req.send();
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`
|
|
@ -1,71 +0,0 @@
|
|||
package chrome_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
)
|
||||
|
||||
var _ = Describe("Chrome tests", func() {
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with version %s", version), func() {
|
||||
JustBeforeEach(func() {
|
||||
testserver.StartQuicServer([]protocol.VersionNumber{version})
|
||||
})
|
||||
|
||||
It("downloads a small file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLen),
|
||||
waitForDone,
|
||||
)
|
||||
})
|
||||
|
||||
It("downloads a large file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/downloadtest?num=1&len=%d", dataLongLen),
|
||||
waitForDone,
|
||||
)
|
||||
})
|
||||
|
||||
It("loads a large number of files", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
"https://quic.clemente.io/downloadtest?num=4&len=100",
|
||||
waitForDone,
|
||||
)
|
||||
})
|
||||
|
||||
It("uploads a small file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLen),
|
||||
waitForNUploaded(1),
|
||||
)
|
||||
})
|
||||
|
||||
It("uploads a large file", func() {
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=1&len=%d", dataLongLen),
|
||||
waitForNUploaded(1),
|
||||
)
|
||||
})
|
||||
|
||||
It("uploads many small files", func() {
|
||||
num := protocol.DefaultMaxIncomingStreams + 20
|
||||
chromeTest(
|
||||
version,
|
||||
fmt.Sprintf("https://quic.clemente.io/uploadtest?num=%d&len=%d", num, dataLen),
|
||||
waitForNUploaded(num),
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -1,137 +0,0 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gbytes"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.DirectionOutgoing, quicproxy.DirectionBoth}
|
||||
|
||||
var _ = Describe("Drop tests", func() {
|
||||
var proxy *quicproxy.QuicProxy
|
||||
|
||||
startProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) {
|
||||
var err error
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:0", version, &quicproxy.Opts{
|
||||
RemoteAddr: "localhost:" + testserver.Port(),
|
||||
DropPacket: dropCallback,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
downloadFile := func(version protocol.VersionNumber) {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}
|
||||
|
||||
downloadHello := func(version protocol.VersionNumber) {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/hello",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(session.Out).To(Say(":status 200"))
|
||||
Expect(session.Out).To(Say("body: Hello, World!\n"))
|
||||
}
|
||||
|
||||
deterministicDropper := func(p, interval, dropInARow uint64) bool {
|
||||
return (p % interval) < dropInARow
|
||||
}
|
||||
|
||||
stochasticDropper := func(freq int) bool {
|
||||
return mrand.Int63n(int64(freq)) == 0
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(proxy.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
Context("during the crypto handshake", func() {
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p == 1 && d.Is(direction)
|
||||
}, version)
|
||||
downloadHello(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p == 2 && d.Is(direction)
|
||||
}, version)
|
||||
downloadHello(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("establishes a connection when 1/5 of the packets are lost in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return d.Is(direction) && stochasticDropper(5)
|
||||
}, version)
|
||||
downloadHello(version)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
Context("after the crypto handshake", func() {
|
||||
for _, d := range directions {
|
||||
direction := d
|
||||
|
||||
It(fmt.Sprintf("downloads a file when every 5th packet is dropped in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p >= 10 && d.Is(direction) && deterministicDropper(p, 5, 1)
|
||||
}, version)
|
||||
downloadFile(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("downloads a file when 1/5th of all packet are dropped randomly in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p >= 10 && d.Is(direction) && stochasticDropper(5)
|
||||
}, version)
|
||||
downloadFile(version)
|
||||
})
|
||||
|
||||
It(fmt.Sprintf("downloads a file when 10 packets every 100 packet are dropped in %s direction", d), func() {
|
||||
startProxy(func(d quicproxy.Direction, p uint64) bool {
|
||||
return p >= 10 && d.Is(direction) && deterministicDropper(p, 100, 10)
|
||||
}, version)
|
||||
downloadFile(version)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -1,45 +0,0 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-go/integrationtests/tools/testlog"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
clientPath string
|
||||
serverPath string
|
||||
)
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "GQuic Tests Suite")
|
||||
}
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
rand.Seed(GinkgoRandomSeed())
|
||||
})
|
||||
|
||||
var _ = JustBeforeEach(func() {
|
||||
testserver.StartQuicServer(nil)
|
||||
})
|
||||
|
||||
var _ = AfterEach(testserver.StopQuicServer)
|
||||
|
||||
func init() {
|
||||
_, thisfile, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
panic("Failed to get current path")
|
||||
}
|
||||
clientPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/client-%s-debug", runtime.GOOS))
|
||||
serverPath = filepath.Join(thisfile, fmt.Sprintf("../../../../quic-clients/server-%s-debug", runtime.GOOS))
|
||||
}
|
|
@ -1,98 +0,0 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gbytes"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var _ = Describe("Integration tests", func() {
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
It("gets a simple file", func() {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"https://quic.clemente.io/hello",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 5).Should(Exit(0))
|
||||
Expect(session.Out).To(Say(":status 200"))
|
||||
Expect(session.Out).To(Say("body: Hello, World!\n"))
|
||||
})
|
||||
|
||||
It("posts and reads a body", func() {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"--body=foo",
|
||||
"https://quic.clemente.io/echo",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 5).Should(Exit(0))
|
||||
Expect(session.Out).To(Say(":status 200"))
|
||||
Expect(session.Out).To(Say("body: foo\n"))
|
||||
})
|
||||
|
||||
It("gets a file", func() {
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 10).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("gets many copies of a file in parallel", func() {
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer GinkgoRecover()
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+testserver.Port(),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -1,95 +0,0 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
// get a random duration between min and max
|
||||
func getRandomDuration(min, max time.Duration) time.Duration {
|
||||
return min + time.Duration(rand.Int63n(int64(max-min)))
|
||||
}
|
||||
|
||||
var _ = Describe("Random Duration Generator", func() {
|
||||
It("gets a random RTT", func() {
|
||||
var min time.Duration = time.Hour
|
||||
var max time.Duration
|
||||
|
||||
var sum time.Duration
|
||||
rep := 10000
|
||||
for i := 0; i < rep; i++ {
|
||||
val := getRandomDuration(100*time.Millisecond, 500*time.Millisecond)
|
||||
sum += val
|
||||
if val < min {
|
||||
min = val
|
||||
}
|
||||
if val > max {
|
||||
max = val
|
||||
}
|
||||
}
|
||||
avg := sum / time.Duration(rep)
|
||||
Expect(avg).To(BeNumerically("~", 300*time.Millisecond, 5*time.Millisecond))
|
||||
Expect(min).To(BeNumerically(">=", 100*time.Millisecond))
|
||||
Expect(min).To(BeNumerically("<", 105*time.Millisecond))
|
||||
Expect(max).To(BeNumerically(">", 495*time.Millisecond))
|
||||
Expect(max).To(BeNumerically("<=", 500*time.Millisecond))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Random RTT", func() {
|
||||
var proxy *quicproxy.QuicProxy
|
||||
|
||||
runRTTTest := func(minRtt, maxRtt time.Duration, version protocol.VersionNumber) {
|
||||
var err error
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
|
||||
RemoteAddr: "localhost:" + testserver.Port(),
|
||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
|
||||
return getRandomDuration(minRtt, maxRtt)
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
err := proxy.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(time.Millisecond)
|
||||
})
|
||||
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
It("gets a file a random RTT between 10ms and 30ms", func() {
|
||||
runRTTTest(10*time.Millisecond, 30*time.Millisecond, version)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -1,66 +0,0 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
_ "github.com/lucas-clemente/quic-clients" // download clients
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var _ = Describe("non-zero RTT", func() {
|
||||
var proxy *quicproxy.QuicProxy
|
||||
|
||||
runRTTTest := func(rtt time.Duration, version protocol.VersionNumber) {
|
||||
var err error
|
||||
proxy, err = quicproxy.NewQuicProxy("localhost:", version, &quicproxy.Opts{
|
||||
RemoteAddr: "localhost:" + testserver.Port(),
|
||||
DelayPacket: func(_ quicproxy.Direction, _ uint64) time.Duration {
|
||||
return rtt / 2
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
command := exec.Command(
|
||||
clientPath,
|
||||
"--quic-version="+version.ToAltSvc(),
|
||||
"--host=127.0.0.1",
|
||||
"--port="+strconv.Itoa(proxy.LocalPort()),
|
||||
"https://quic.clemente.io/prdata",
|
||||
)
|
||||
|
||||
session, err := Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
defer session.Kill()
|
||||
Eventually(session, 20).Should(Exit(0))
|
||||
Expect(bytes.Contains(session.Out.Contents(), testserver.PRData)).To(BeTrue())
|
||||
}
|
||||
|
||||
AfterEach(func() {
|
||||
err := proxy.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
time.Sleep(time.Millisecond)
|
||||
})
|
||||
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
roundTrips := [...]int{10, 50, 100, 200}
|
||||
for _, rtt := range roundTrips {
|
||||
It(fmt.Sprintf("gets a 500kB file with %dms RTT", rtt), func() {
|
||||
runRTTTest(time.Duration(rtt)*time.Millisecond, version)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
|
@ -1,218 +0,0 @@
|
|||
package gquic_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
mrand "math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
quic "github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
. "github.com/onsi/gomega/gexec"
|
||||
)
|
||||
|
||||
var _ = Describe("Server tests", func() {
|
||||
for i := range protocol.SupportedVersions {
|
||||
version := protocol.SupportedVersions[i]
|
||||
|
||||
var (
|
||||
serverPort string
|
||||
tmpDir string
|
||||
session *Session
|
||||
client *http.Client
|
||||
)
|
||||
|
||||
generateCA := func() (*rsa.PrivateKey, *x509.Certificate) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
templateRoot := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, templateRoot, templateRoot, &key.PublicKey, key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return key, cert
|
||||
}
|
||||
|
||||
// prepare the file such that it can be by the quic_server
|
||||
// some HTTP headers neeed to be prepended, see https://www.chromium.org/quic/playing-with-quic
|
||||
createDownloadFile := func(filename string, data []byte) {
|
||||
dataDir := filepath.Join(tmpDir, "quic.clemente.io")
|
||||
err := os.Mkdir(dataDir, 0777)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f, err := os.Create(filepath.Join(dataDir, filename))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
defer f.Close()
|
||||
_, err = f.Write([]byte("HTTP/1.1 200 OK\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write([]byte("Content-Type: text/html\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write([]byte("X-Original-Url: https://quic.clemente.io:" + serverPort + "/" + filename + "\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write([]byte("Content-Length: " + strconv.Itoa(len(data)) + "\n\n"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
// download files must be create *before* the quic_server is started
|
||||
// the quic_server reads its data dir on startup, and only serves those files that were already present then
|
||||
startServer := func(version protocol.VersionNumber) {
|
||||
defer GinkgoRecover()
|
||||
var err error
|
||||
command := exec.Command(
|
||||
serverPath,
|
||||
"--quic_response_cache_dir="+filepath.Join(tmpDir, "quic.clemente.io"),
|
||||
"--key_file="+filepath.Join(tmpDir, "key.pkcs8"),
|
||||
"--certificate_file="+filepath.Join(tmpDir, "cert.pem"),
|
||||
"--quic-version="+strconv.Itoa(int(version)),
|
||||
"--port="+serverPort,
|
||||
)
|
||||
session, err = Start(command, nil, GinkgoWriter)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
stopServer := func() {
|
||||
session.Kill()
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
serverPort = strconv.Itoa(20000 + int(mrand.Int31n(10000)))
|
||||
|
||||
var err error
|
||||
tmpDir, err = ioutil.TempDir("", "quic-server-certs")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// generate an RSA key pair for the server
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// save the private key in PKCS8 format to disk (required by quic_server)
|
||||
pkcs8key, err := asn1.Marshal(struct { // copied from the x509 package
|
||||
Version int
|
||||
Algo pkix.AlgorithmIdentifier
|
||||
PrivateKey []byte
|
||||
}{
|
||||
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
|
||||
Algo: pkix.AlgorithmIdentifier{
|
||||
Algorithm: asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1},
|
||||
Parameters: asn1.RawValue{Tag: 5},
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f, err := os.Create(filepath.Join(tmpDir, "key.pkcs8"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = f.Write(pkcs8key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
f.Close()
|
||||
|
||||
// generate a Certificate Authority
|
||||
// this CA is used to sign the server's key
|
||||
// it is set as a valid CA in the QUIC client
|
||||
rootKey, CACert := generateCA()
|
||||
// generate the server certificate
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-30 * time.Minute),
|
||||
NotAfter: time.Now().Add(30 * time.Minute),
|
||||
Subject: pkix.Name{CommonName: "quic.clemente.io"},
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, CACert, &key.PublicKey, rootKey)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// save the certificate to disk
|
||||
certOut, err := os.Create(filepath.Join(tmpDir, "cert.pem"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
certOut.Close()
|
||||
|
||||
// prepare the h2quic.client
|
||||
certPool := x509.NewCertPool()
|
||||
certPool.AddCert(CACert)
|
||||
client = &http.Client{
|
||||
Transport: &h2quic.RoundTripper{
|
||||
TLSClientConfig: &tls.Config{RootCAs: certPool},
|
||||
QuicConfig: &quic.Config{
|
||||
Versions: []protocol.VersionNumber{version},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(tmpDir).ToNot(BeEmpty())
|
||||
err := os.RemoveAll(tmpDir)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tmpDir = ""
|
||||
})
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
It("downloads a hello", func() {
|
||||
data := []byte("Hello world!\n")
|
||||
createDownloadFile("hello", data)
|
||||
|
||||
startServer(version)
|
||||
defer stopServer()
|
||||
|
||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/hello")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(data))
|
||||
})
|
||||
|
||||
It("downloads a small file", func() {
|
||||
createDownloadFile("file.dat", testserver.PRData)
|
||||
|
||||
startServer(version)
|
||||
defer stopServer()
|
||||
|
||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 5*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(testserver.PRData))
|
||||
})
|
||||
|
||||
It("downloads a large file", func() {
|
||||
createDownloadFile("file.dat", testserver.PRDataLong)
|
||||
|
||||
startServer(version)
|
||||
defer stopServer()
|
||||
|
||||
rsp, err := client.Get("https://quic.clemente.io:" + serverPort + "/file.dat")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.StatusCode).To(Equal(200))
|
||||
body, err := ioutil.ReadAll(gbytes.TimeoutReader(rsp.Body, 20*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(body).To(Equal(testserver.PRDataLong))
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go/h2quic"
|
||||
"github.com/lucas-clemente/quic-go/integrationtests/tools/testserver"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/onsi/gomega/gbytes"
|
||||
|
@ -21,8 +22,7 @@ import (
|
|||
var _ = Describe("Client tests", func() {
|
||||
var client *http.Client
|
||||
|
||||
// also run some tests with the TLS handshake
|
||||
versions := append(protocol.SupportedVersions, protocol.VersionTLS)
|
||||
versions := protocol.SupportedVersions
|
||||
|
||||
BeforeEach(func() {
|
||||
err := os.Setenv("HOSTALIASES", "quic.clemente.io 127.0.0.1")
|
||||
|
|
|
@ -60,42 +60,32 @@ var _ = Describe("Connection ID lengths tests", func() {
|
|||
Expect(data).To(Equal(testserver.PRData))
|
||||
}
|
||||
|
||||
Context("IETF QUIC", func() {
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
serverConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
clientConf := &quic.Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
serverConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
clientConf := &quic.Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
|
||||
It("downloads a file when both client and server use a random connection ID length", func() {
|
||||
serverConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
clientConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
|
||||
Context("gQUIC", func() {
|
||||
It("downloads a file using a 0-byte connection ID for the client", func() {
|
||||
ln := runServer(&quic.Config{})
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), &quic.Config{RequestConnectionIDOmission: true})
|
||||
})
|
||||
It("downloads a file when both client and server use a random connection ID length", func() {
|
||||
serverConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
clientConf := &quic.Config{
|
||||
ConnectionIDLength: randomConnIDLen(),
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
|
||||
ln := runServer(serverConf)
|
||||
defer ln.Close()
|
||||
runClient(ln.Addr(), clientConf)
|
||||
})
|
||||
})
|
||||
|
|
|
@ -87,128 +87,57 @@ var _ = Describe("Handshake RTT tests", func() {
|
|||
expectDurationInRTTs(1)
|
||||
})
|
||||
|
||||
Context("gQUIC", func() {
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT to become secure
|
||||
// 1 RTT to become forward-secure
|
||||
It("is forward-secure after 3 RTTs", func() {
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(3)
|
||||
})
|
||||
var clientConfig *quic.Config
|
||||
var clientTLSConfig *tls.Config
|
||||
|
||||
It("does version negotiation in 1 RTT, IETF QUIC => gQUIC", func() {
|
||||
clientConfig := &quic.Config{
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS, protocol.SupportedVersions[0]},
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(4)
|
||||
})
|
||||
|
||||
It("is forward-secure after 2 RTTs when the server doesn't require a Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return true
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
||||
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return false
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
|
||||
})
|
||||
|
||||
It("doesn't complete the handshake when the handshake timeout is too short", func() {
|
||||
serverConfig.HandshakeTimeout = 2 * rtt
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
// 2 RTTs during the timeout
|
||||
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
|
||||
expectDurationInRTTs(3)
|
||||
})
|
||||
BeforeEach(func() {
|
||||
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
||||
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
clientTLSConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "quic.clemente.io",
|
||||
}
|
||||
})
|
||||
|
||||
Context("IETF QUIC", func() {
|
||||
var clientConfig *quic.Config
|
||||
var clientTLSConfig *tls.Config
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT for the TLS handshake
|
||||
It("is forward-secure after 2 RTTs", func() {
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
serverConfig.Versions = []protocol.VersionNumber{protocol.VersionTLS}
|
||||
clientConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
clientTLSConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "quic.clemente.io",
|
||||
}
|
||||
})
|
||||
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return true
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(1)
|
||||
})
|
||||
|
||||
// 1 RTT for verifying the source address
|
||||
// 1 RTT for the TLS handshake
|
||||
It("is forward-secure after 2 RTTs", func() {
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(2)
|
||||
})
|
||||
|
||||
It("does version negotiation in 1 RTT, gQUIC => IETF QUIC", func() {
|
||||
clientConfig.Versions = []protocol.VersionNumber{protocol.SupportedVersions[0], protocol.VersionTLS}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(3)
|
||||
})
|
||||
|
||||
It("is forward-secure after 1 RTTs when the server doesn't require a Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return true
|
||||
}
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectDurationInRTTs(1)
|
||||
})
|
||||
|
||||
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return false
|
||||
}
|
||||
clientConfig.HandshakeTimeout = 500 * time.Millisecond
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
})
|
||||
It("doesn't complete the handshake when the server never accepts the Cookie", func() {
|
||||
serverConfig.AcceptCookie = func(_ net.Addr, _ *quic.Cookie) bool {
|
||||
return false
|
||||
}
|
||||
clientConfig.HandshakeTimeout = 500 * time.Millisecond
|
||||
runServerAndProxy()
|
||||
_, err := quic.DialAddr(
|
||||
proxy.LocalAddr().String(),
|
||||
clientTLSConfig,
|
||||
clientConfig,
|
||||
)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
})
|
||||
})
|
||||
|
|
|
@ -91,7 +91,7 @@ var _ = Describe("Handshake tests", func() {
|
|||
})
|
||||
|
||||
Context("Certifiate validation", func() {
|
||||
for _, v := range []protocol.VersionNumber{protocol.Version39, protocol.VersionTLS} {
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("using %s", version), func() {
|
||||
|
|
|
@ -18,15 +18,9 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Multiplexing", func() {
|
||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
version := v
|
||||
|
||||
// gQUIC 44 uses 0 byte connection IDs for packets sent to the client
|
||||
// It's not possible to do demultiplexing.
|
||||
if v == protocol.Version44 {
|
||||
continue
|
||||
}
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
runServer := func(ln quic.Listener) {
|
||||
go func() {
|
||||
|
@ -143,10 +137,6 @@ var _ = Describe("Multiplexing", func() {
|
|||
|
||||
Context("multiplexing server and client on the same conn", func() {
|
||||
It("connects to itself", func() {
|
||||
if version != protocol.VersionTLS {
|
||||
Skip("Connecting to itself only works with IETF QUIC.")
|
||||
}
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("non-zero RTT", func() {
|
||||
for _, v := range append(protocol.SupportedVersions, protocol.VersionTLS) {
|
||||
for _, v := range protocol.SupportedVersions {
|
||||
version := v
|
||||
|
||||
Context(fmt.Sprintf("with QUIC version %s", version), func() {
|
||||
|
|
16
interface.go
16
interface.go
|
@ -16,15 +16,6 @@ type StreamID = protocol.StreamID
|
|||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
const (
|
||||
// VersionGQUIC39 is gQUIC version 39.
|
||||
VersionGQUIC39 = protocol.Version39
|
||||
// VersionGQUIC43 is gQUIC version 43.
|
||||
VersionGQUIC43 = protocol.Version43
|
||||
// VersionGQUIC44 is gQUIC version 44.
|
||||
VersionGQUIC44 = protocol.Version44
|
||||
)
|
||||
|
||||
// A Cookie can be used to verify the ownership of the client address.
|
||||
type Cookie = handshake.Cookie
|
||||
|
||||
|
@ -164,11 +155,7 @@ type Config struct {
|
|||
// If not set, it uses all versions available.
|
||||
// Warning: This API should not be considered stable and will change soon.
|
||||
Versions []VersionNumber
|
||||
// Ask the server to omit the connection ID sent in the Public Header.
|
||||
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
|
||||
// Currently only valid for the client.
|
||||
RequestConnectionIDOmission bool
|
||||
// The length of the connection ID in bytes. Only valid for IETF QUIC.
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If not set, the interpretation depends on where the Config is used:
|
||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||
|
@ -201,7 +188,6 @@ type Config struct {
|
|||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
MaxIncomingStreams int
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// This value doesn't have any effect in Google QUIC.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
// Values larger than 65535 (math.MaxUint16) are invalid.
|
||||
|
|
|
@ -27,7 +27,6 @@ type SentPacketHandler interface {
|
|||
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
|
||||
ShouldSendNumPackets() int
|
||||
|
||||
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
DequeuePacketForRetransmission() *Packet
|
||||
DequeueProbePacket() (*Packet, error)
|
||||
|
|
|
@ -16,8 +16,6 @@ func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
|
|||
// IsFrameRetransmittable returns true if the frame should be retransmitted.
|
||||
func IsFrameRetransmittable(f wire.Frame) bool {
|
||||
switch f.(type) {
|
||||
case *wire.StopWaitingFrame:
|
||||
return false
|
||||
case *wire.AckFrame:
|
||||
return false
|
||||
default:
|
||||
|
|
|
@ -11,10 +11,8 @@ import (
|
|||
var _ = Describe("retransmittable frames", func() {
|
||||
for fl, el := range map[wire.Frame]bool{
|
||||
&wire.AckFrame{}: false,
|
||||
&wire.StopWaitingFrame{}: false,
|
||||
&wire.BlockedFrame{}: true,
|
||||
&wire.ConnectionCloseFrame{}: true,
|
||||
&wire.GoawayFrame{}: true,
|
||||
&wire.PingFrame{}: true,
|
||||
&wire.RstStreamFrame{}: true,
|
||||
&wire.StreamFrame{}: true,
|
||||
|
|
|
@ -45,8 +45,7 @@ type sentPacketHandler struct {
|
|||
lowestPacketNotConfirmedAcked protocol.PacketNumber
|
||||
largestSentBeforeRTO protocol.PacketNumber
|
||||
|
||||
packetHistory *sentPacketHistory
|
||||
stopWaitingManager stopWaitingManager
|
||||
packetHistory *sentPacketHistory
|
||||
|
||||
retransmissionQueue []*Packet
|
||||
|
||||
|
@ -90,12 +89,11 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
|
|||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
packetHistory: newSentPacketHistory(),
|
||||
stopWaitingManager: stopWaitingManager{},
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
packetHistory: newSentPacketHistory(),
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,15 +108,13 @@ func (h *sentPacketHandler) SetHandshakeComplete() {
|
|||
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
|
||||
var queue []*Packet
|
||||
for _, packet := range h.retransmissionQueue {
|
||||
if packet.EncryptionLevel == protocol.EncryptionForwardSecure ||
|
||||
packet.EncryptionLevel == protocol.Encryption1RTT {
|
||||
if packet.EncryptionLevel == protocol.Encryption1RTT {
|
||||
queue = append(queue, packet)
|
||||
}
|
||||
}
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.EncryptionLevel != protocol.EncryptionForwardSecure &&
|
||||
p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
if p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
|
@ -169,8 +165,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt
|
|||
isRetransmittable := len(packet.Frames) != 0
|
||||
|
||||
if isRetransmittable {
|
||||
if packet.EncryptionLevel != protocol.EncryptionForwardSecure &&
|
||||
packet.EncryptionLevel != protocol.Encryption1RTT {
|
||||
if packet.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.lastSentHandshakePacketTime = packet.SendTime
|
||||
}
|
||||
h.lastSentRetransmittablePacketTime = packet.SendTime
|
||||
|
@ -217,12 +212,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
|
||||
priorInFlight := h.bytesInFlight
|
||||
for _, p := range ackedPackets {
|
||||
// TODO(#1534): also check the encryption level for IETF QUIC
|
||||
if !h.version.UsesTLS() {
|
||||
if encLevel < p.EncryptionLevel {
|
||||
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||
}
|
||||
}
|
||||
// TODO(#1534): check the encryption level
|
||||
// if encLevel < p.EncryptionLevel {
|
||||
// return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
|
||||
// }
|
||||
|
||||
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
|
||||
// It is safe to ignore the corner case of packets that just acked packet 0, because
|
||||
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
|
||||
|
@ -243,8 +237,6 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe
|
|||
h.updateLossDetectionAlarm()
|
||||
|
||||
h.garbageCollectSkippedPackets()
|
||||
h.stopWaitingManager.ReceivedAck(ackFrame)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -530,10 +522,6 @@ func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol
|
|||
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
return h.stopWaitingManager.GetStopWaitingFrame(force)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendMode() SendMode {
|
||||
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
|
||||
|
||||
|
@ -592,9 +580,7 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
|
|||
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
|
||||
var handshakePackets []*Packet
|
||||
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.canBeRetransmitted &&
|
||||
p.EncryptionLevel != protocol.EncryptionForwardSecure &&
|
||||
p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
if p.canBeRetransmitted && p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
handshakePackets = append(handshakePackets, p)
|
||||
}
|
||||
return true, nil
|
||||
|
@ -616,7 +602,6 @@ func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
|
|||
return err
|
||||
}
|
||||
h.retransmissionQueue = append(h.retransmissionQueue, p)
|
||||
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
func retransmittablePacket(p *Packet) *Packet {
|
||||
if p.EncryptionLevel == protocol.EncryptionUnspecified {
|
||||
p.EncryptionLevel = protocol.EncryptionForwardSecure
|
||||
p.EncryptionLevel = protocol.Encryption1RTT
|
||||
}
|
||||
if p.Length == 0 {
|
||||
p.Length = 1
|
||||
|
@ -37,7 +37,7 @@ func nonRetransmittablePacket(p *Packet) *Packet {
|
|||
|
||||
func handshakePacket(p *Packet) *Packet {
|
||||
p = retransmittablePacket(p)
|
||||
p.EncryptionLevel = protocol.EncryptionUnencrypted
|
||||
p.EncryptionLevel = protocol.EncryptionInitial
|
||||
return p
|
||||
}
|
||||
|
||||
|
@ -123,8 +123,8 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
|
||||
It("stores the sent time of handshake packets", func() {
|
||||
sendTime := time.Now().Add(-time.Minute)
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionUnencrypted}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.EncryptionForwardSecure}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1, SendTime: sendTime, EncryptionLevel: protocol.EncryptionInitial}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2, SendTime: sendTime.Add(time.Hour), EncryptionLevel: protocol.Encryption1RTT}))
|
||||
Expect(handler.lastSentHandshakePacketTime).To(Equal(sendTime))
|
||||
})
|
||||
|
||||
|
@ -205,7 +205,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
ack := &wire.AckFrame{
|
||||
AckRanges: []wire.AckRange{{Smallest: 10, Largest: 12}},
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).To(MatchError("InvalidAckData: Received an ACK for a skipped packet number"))
|
||||
})
|
||||
|
||||
|
@ -216,7 +216,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{Smallest: 10, Largest: 10},
|
||||
},
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).ToNot(BeZero())
|
||||
})
|
||||
|
@ -237,7 +237,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Context("ACK validation", func() {
|
||||
It("accepts ACKs sent in packet 0", func() {
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}}
|
||||
err := handler.ReceivedAck(ack, 0, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 0, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(5)))
|
||||
})
|
||||
|
@ -245,12 +245,12 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
It("rejects duplicate ACKs", func() {
|
||||
ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}}
|
||||
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 4}}}
|
||||
err := handler.ReceivedAck(ack1, 1337, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack1, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
|
||||
// this wouldn't happen in practice
|
||||
// for testing purposes, we pretend send a different ACK frame in a duplicated packet, to be able to verify that it actually doesn't get processed
|
||||
err = handler.ReceivedAck(ack2, 1337, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack2, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
|
||||
})
|
||||
|
@ -259,28 +259,28 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
// acks packets 0, 1, 2, 3
|
||||
ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 3}}}
|
||||
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 4}}}
|
||||
err := handler.ReceivedAck(ack1, 1337, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack1, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// this wouldn't happen in practive
|
||||
// a receiver wouldn't send an ACK for a lower largest acked in a packet sent later
|
||||
err = handler.ReceivedAck(ack2, 1337-1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack2, 1337-1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
|
||||
})
|
||||
|
||||
It("rejects ACKs with a too high LargestAcked packet number", func() {
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 9999}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).To(MatchError("InvalidAckData: Received ACK for an unsent package"))
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(10)))
|
||||
})
|
||||
|
||||
It("ignores repeated ACKs", func() {
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}}
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1337, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7)))
|
||||
err = handler.ReceivedAck(ack, 1337+1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 1337+1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(3)))
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(7)))
|
||||
|
@ -290,7 +290,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Context("acks and nacks the right packets", func() {
|
||||
It("adjusts the LargestAcked, and adjusts the bytes in flight", func() {
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 5}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.largestAcked).To(Equal(protocol.PacketNumber(5)))
|
||||
expectInPacketHistory([]protocol.PacketNumber{6, 7, 8, 9})
|
||||
|
@ -299,7 +299,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
|
||||
It("acks packet 0", func() {
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 0, Largest: 0}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(getPacket(0)).To(BeNil())
|
||||
expectInPacketHistory([]protocol.PacketNumber{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||
|
@ -312,14 +312,14 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{Smallest: 1, Largest: 3},
|
||||
},
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 4, 5})
|
||||
})
|
||||
|
||||
It("does not ack packets below the LowestAcked", func() {
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 8}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 1, 2, 9})
|
||||
})
|
||||
|
@ -333,7 +333,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{Smallest: 1, Largest: 1},
|
||||
},
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 2, 4, 5, 8})
|
||||
})
|
||||
|
@ -345,12 +345,12 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{Smallest: 1, Largest: 2},
|
||||
},
|
||||
}
|
||||
err := handler.ReceivedAck(ack1, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack1, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 3, 7, 8, 9})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(5)))
|
||||
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}} // now ack 3
|
||||
err = handler.ReceivedAck(ack2, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack2, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 7, 8, 9})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4)))
|
||||
|
@ -363,12 +363,12 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{Smallest: 0, Largest: 2},
|
||||
},
|
||||
}
|
||||
err := handler.ReceivedAck(ack1, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack1, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{3, 7, 8, 9})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4)))
|
||||
ack2 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 7}}}
|
||||
err = handler.ReceivedAck(ack2, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack2, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(2)))
|
||||
expectInPacketHistory([]protocol.PacketNumber{8, 9})
|
||||
|
@ -376,7 +376,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
|
||||
It("processes an ACK that contains old ACK ranges", func() {
|
||||
ack1 := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}}
|
||||
err := handler.ReceivedAck(ack1, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack1, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 7, 8, 9})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(4)))
|
||||
|
@ -387,7 +387,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{Smallest: 1, Largest: 1},
|
||||
},
|
||||
}
|
||||
err = handler.ReceivedAck(ack2, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack2, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{0, 7, 9})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(3)))
|
||||
|
@ -403,15 +403,15 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
getPacket(6).SendTime = now.Add(-1 * time.Minute)
|
||||
// Now, check that the proper times are used when calculating the deltas
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 10*time.Minute, 1*time.Second))
|
||||
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}}
|
||||
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
|
||||
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 6}}}
|
||||
err = handler.ReceivedAck(ack, 3, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 3, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 1*time.Minute, 1*time.Second))
|
||||
})
|
||||
|
@ -425,7 +425,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}},
|
||||
DelayTime: 5 * time.Minute,
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.LatestRTT()).To(BeNumerically("~", 5*time.Minute, 1*time.Second))
|
||||
})
|
||||
|
@ -446,27 +446,27 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
})
|
||||
|
||||
It("determines which ACK we have received an ACK for", func() {
|
||||
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}}, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 15}}}, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
|
||||
})
|
||||
|
||||
It("doesn't do anything when the acked packet didn't contain an ACK", func() {
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
|
||||
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 15, Largest: 15}}}
|
||||
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(101)))
|
||||
})
|
||||
|
||||
It("doesn't decrease the value", func() {
|
||||
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}}, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 14, Largest: 14}}}, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
|
||||
err = handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetLowestPacketNotConfirmedAcked()).To(Equal(protocol.PacketNumber(201)))
|
||||
})
|
||||
|
@ -492,7 +492,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11)))
|
||||
// ack 5
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 5, Largest: 5}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expectInPacketHistory([]protocol.PacketNumber{6})
|
||||
Expect(handler.bytesInFlight).To(Equal(protocol.ByteCount(11)))
|
||||
|
@ -510,36 +510,16 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
{Smallest: 5, Largest: 5},
|
||||
},
|
||||
}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.packetHistory.Len()).To(BeZero())
|
||||
Expect(handler.bytesInFlight).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Retransmission handling", func() {
|
||||
It("does not dequeue a packet if no ack has been received", func() {
|
||||
handler.SentPacket(&Packet{PacketNumber: 1})
|
||||
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
|
||||
})
|
||||
|
||||
Context("STOP_WAITINGs", func() {
|
||||
It("gets a STOP_WAITING frame", func() {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 1}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 3}}}
|
||||
err := handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 4}))
|
||||
})
|
||||
|
||||
It("gets a STOP_WAITING frame after queueing a retransmission", func() {
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 5}))
|
||||
handler.queuePacketForRetransmission(getPacket(5))
|
||||
Expect(handler.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 6}))
|
||||
})
|
||||
})
|
||||
It("does not dequeue a packet if no ack has been received", func() {
|
||||
handler.SentPacket(&Packet{PacketNumber: 1})
|
||||
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
|
||||
})
|
||||
|
||||
Context("congestion", func() {
|
||||
|
@ -580,7 +560,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 2}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, rcvTime)
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, rcvTime)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -618,7 +598,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
)
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 5}))
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 5, Largest: 5}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, rcvTime)
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, rcvTime)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -641,7 +621,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(3)),
|
||||
)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -657,11 +637,11 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(2)),
|
||||
)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// don't EXPECT any further calls to the congestion controller
|
||||
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 2}}}
|
||||
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -679,7 +659,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
cong.EXPECT().OnPacketLost(protocol.PacketNumber(1), protocol.ByteCount(1), protocol.ByteCount(4)),
|
||||
)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now().Add(-30*time.Minute))
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now().Add(-30*time.Minute))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// receive the second ACK
|
||||
gomock.InOrder(
|
||||
|
@ -688,7 +668,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
cong.EXPECT().OnPacketLost(protocol.PacketNumber(3), protocol.ByteCount(1), protocol.ByteCount(2)),
|
||||
)
|
||||
ack = &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 4, Largest: 4}}}
|
||||
err = handler.ReceivedAck(ack, 2, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 2, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
|
@ -778,7 +758,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 10}))
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 11}))
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.GetAlarmTimeout()).To(BeZero())
|
||||
})
|
||||
|
@ -906,7 +886,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
// This verifies the RTO.
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 3}))
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 3, Largest: 3}}}
|
||||
err = handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.packetHistory.Len()).To(BeZero())
|
||||
Expect(handler.bytesInFlight).To(BeZero())
|
||||
|
@ -945,7 +925,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
// This verifies the RTO.
|
||||
handler.SentPacket(retransmittablePacket(&Packet{PacketNumber: 6}))
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 6, Largest: 6}}}
|
||||
err = handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err = handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handler.packetHistory.Len()).To(BeZero())
|
||||
Expect(handler.bytesInFlight).To(BeZero())
|
||||
|
@ -960,7 +940,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
handler.OnAlarm() // RTO
|
||||
handler.SentPacketsAsRetransmission([]*Packet{retransmittablePacket(&Packet{PacketNumber: 6})}, 5)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 5, Largest: 5}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, time.Now())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = handler.OnAlarm()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -984,7 +964,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(handler.lossTime.IsZero()).To(BeTrue())
|
||||
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, now)
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.DequeuePacketForRetransmission()).ToNot(BeNil())
|
||||
Expect(handler.DequeuePacketForRetransmission()).To(BeNil())
|
||||
|
@ -1001,7 +981,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
Expect(handler.lossTime.IsZero()).To(BeTrue())
|
||||
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, now.Add(-time.Second))
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now.Add(-time.Second))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Second))
|
||||
|
||||
|
@ -1033,7 +1013,7 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
handler.SentPacket(handshakePacket(&Packet{PacketNumber: 3, SendTime: sendTime}))
|
||||
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionForwardSecure, now)
|
||||
err := handler.ReceivedAck(ack, 1, protocol.Encryption1RTT, now)
|
||||
// RTT is now 1 minute
|
||||
Expect(handler.rttStats.SmoothedRTT()).To(Equal(time.Minute))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -1055,19 +1035,19 @@ var _ = Describe("SentPacketHandler", func() {
|
|||
PIt("rejects an ACK that acks packets with a higher encryption level", func() {
|
||||
handler.SentPacket(&Packet{
|
||||
PacketNumber: 13,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
Frames: []wire.Frame{&streamFrame},
|
||||
Length: 1,
|
||||
})
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 13, Largest: 13}}}
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionSecure, time.Now())
|
||||
err := handler.ReceivedAck(ack, 1, protocol.EncryptionHandshake, time.Now())
|
||||
Expect(err).To(MatchError("Received ACK with encryption level encrypted (not forward-secure) that acks a packet 13 (encryption level forward-secure)"))
|
||||
})
|
||||
|
||||
It("deletes non forward-secure packets when the handshake completes", func() {
|
||||
It("deletes non handshake packets when the handshake completes", func() {
|
||||
for i := protocol.PacketNumber(1); i <= 6; i++ {
|
||||
p := retransmittablePacket(&Packet{PacketNumber: i})
|
||||
p.EncryptionLevel = protocol.EncryptionSecure
|
||||
p.EncryptionLevel = protocol.EncryptionHandshake
|
||||
handler.SentPacket(p)
|
||||
}
|
||||
handler.queuePacketForRetransmission(getPacket(1))
|
||||
|
|
|
@ -35,8 +35,7 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
|
|||
}
|
||||
if p.canBeRetransmitted {
|
||||
h.numOutstandingPackets++
|
||||
if p.EncryptionLevel != protocol.EncryptionForwardSecure &&
|
||||
p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
if p.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.numOutstandingHandshakePackets++
|
||||
}
|
||||
}
|
||||
|
@ -107,8 +106,7 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber)
|
|||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel != protocol.EncryptionForwardSecure &&
|
||||
el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
|
@ -149,8 +147,7 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
|||
if h.numOutstandingPackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
}
|
||||
if el.Value.EncryptionLevel != protocol.EncryptionForwardSecure &&
|
||||
el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||
if el.Value.EncryptionLevel != protocol.Encryption1RTT {
|
||||
h.numOutstandingHandshakePackets--
|
||||
if h.numOutstandingHandshakePackets < 0 {
|
||||
panic("numOutstandingHandshakePackets negative")
|
||||
|
|
|
@ -202,7 +202,7 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
It("says if it has outstanding handshake packets", func() {
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
hist.SentPacket(&Packet{
|
||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
||||
EncryptionLevel: protocol.EncryptionInitial,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
||||
|
@ -212,7 +212,7 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
||||
hist.SentPacket(&Packet{
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
|
@ -221,7 +221,7 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
|
||||
It("doesn't consider non-retransmittable packets as outstanding", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
||||
EncryptionLevel: protocol.EncryptionInitial,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeFalse())
|
||||
Expect(hist.HasOutstandingPackets()).To(BeFalse())
|
||||
|
@ -230,7 +230,7 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
It("accounts for deleted handshake packets", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 5,
|
||||
EncryptionLevel: protocol.EncryptionSecure,
|
||||
EncryptionLevel: protocol.EncryptionHandshake,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
||||
|
@ -242,7 +242,7 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
It("accounts for deleted packets", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 10,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
||||
|
@ -254,7 +254,7 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
It("doesn't count handshake packets marked as non-retransmittable", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 5,
|
||||
EncryptionLevel: protocol.EncryptionUnencrypted,
|
||||
EncryptionLevel: protocol.EncryptionInitial,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingHandshakePackets()).To(BeTrue())
|
||||
|
@ -266,7 +266,7 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
It("doesn't count packets marked as non-retransmittable", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 10,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
Expect(hist.HasOutstandingPackets()).To(BeTrue())
|
||||
|
@ -278,12 +278,12 @@ var _ = Describe("SentPacketHistory", func() {
|
|||
It("counts the number of packets", func() {
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 10,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
hist.SentPacket(&Packet{
|
||||
PacketNumber: 11,
|
||||
EncryptionLevel: protocol.EncryptionForwardSecure,
|
||||
EncryptionLevel: protocol.Encryption1RTT,
|
||||
canBeRetransmitted: true,
|
||||
})
|
||||
err := hist.Remove(11)
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
|
||||
type stopWaitingManager struct {
|
||||
largestLeastUnackedSent protocol.PacketNumber
|
||||
nextLeastUnacked protocol.PacketNumber
|
||||
|
||||
lastStopWaitingFrame *wire.StopWaitingFrame
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
|
||||
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
|
||||
if force {
|
||||
return s.lastStopWaitingFrame
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.largestLeastUnackedSent = s.nextLeastUnacked
|
||||
swf := &wire.StopWaitingFrame{
|
||||
LeastUnacked: s.nextLeastUnacked,
|
||||
}
|
||||
s.lastStopWaitingFrame = swf
|
||||
return swf
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
|
||||
largestAcked := ack.LargestAcked()
|
||||
if largestAcked >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = largestAcked + 1
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
|
||||
if p >= s.nextLeastUnacked {
|
||||
s.nextLeastUnacked = p + 1
|
||||
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/wire"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("StopWaitingManager", func() {
|
||||
var manager *stopWaitingManager
|
||||
BeforeEach(func() {
|
||||
manager = &stopWaitingManager{}
|
||||
})
|
||||
|
||||
It("returns nil in the beginning", func() {
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(true)).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns a StopWaitingFrame, when a new ACK arrives", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
||||
})
|
||||
|
||||
It("does not decrease the LeastUnacked", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 9}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
||||
})
|
||||
|
||||
It("does not send the same StopWaitingFrame twice", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(BeNil())
|
||||
})
|
||||
|
||||
It("gets the same StopWaitingFrame twice, if forced", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
Expect(manager.GetStopWaitingFrame(false)).ToNot(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
|
||||
Expect(manager.GetStopWaitingFrame(true)).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("increases the LeastUnacked when a retransmission is queued", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
manager.QueuedRetransmissionForPacketNumber(20)
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 21}))
|
||||
})
|
||||
|
||||
It("does not decrease the LeastUnacked when a retransmission is queued", func() {
|
||||
manager.ReceivedAck(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}})
|
||||
manager.QueuedRetransmissionForPacketNumber(9)
|
||||
Expect(manager.GetStopWaitingFrame(false)).To(Equal(&wire.StopWaitingFrame{LeastUnacked: 11}))
|
||||
})
|
||||
})
|
|
@ -1,72 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/lucas-clemente/aes12"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadAESGCM12 struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
var _ AEAD = &aeadAESGCM12{}
|
||||
|
||||
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
|
||||
//
|
||||
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
|
||||
// tag size, and couples the cipher and aes packages closely.
|
||||
// See https://github.com/lucas-clemente/aes12.
|
||||
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
|
||||
}
|
||||
encrypterCipher, err := aes12.NewCipher(myKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encrypter, err := aes12.NewGCM(encrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypterCipher, err := aes12.NewCipher(otherKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := aes12.NewGCM(decrypterCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadAESGCM12{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
||||
|
||||
func (aead *aeadAESGCM12) Overhead() int {
|
||||
return aead.encrypter.Overhead()
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("AES-GCM", func() {
|
||||
var (
|
||||
alice, bob AEAD
|
||||
keyAlice, keyBob, ivAlice, ivBob []byte
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
keyAlice = make([]byte, 16)
|
||||
keyBob = make([]byte, 16)
|
||||
ivAlice = make([]byte, 4)
|
||||
ivBob = make([]byte, 4)
|
||||
rand.Reader.Read(keyAlice)
|
||||
rand.Reader.Read(keyBob)
|
||||
rand.Reader.Read(ivAlice)
|
||||
rand.Reader.Read(ivBob)
|
||||
var err error
|
||||
alice, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
bob, err = NewAEADAESGCM12(keyAlice, keyBob, ivAlice, ivBob)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("seals and opens", func() {
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
text, err := bob.Open(nil, b, 42, []byte("aad"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(text).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("seals and opens reverse", func() {
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
text, err := alice.Open(nil, b, 42, []byte("aad"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(text).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("has the proper length", func() {
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
Expect(b).To(HaveLen(6 + bob.Overhead()))
|
||||
})
|
||||
|
||||
It("fails with wrong aad", func() {
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
_, err := bob.Open(nil, b, 42, []byte("aad2"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects wrong key and iv sizes", func() {
|
||||
var err error
|
||||
e := "AES-GCM: expected 16-byte keys and 4-byte IVs"
|
||||
_, err = NewAEADAESGCM12(keyBob[1:], keyAlice, ivBob, ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADAESGCM12(keyBob, keyAlice[1:], ivBob, ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob[1:], ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADAESGCM12(keyBob, keyAlice, ivBob, ivAlice[1:])
|
||||
Expect(err).To(MatchError(e))
|
||||
})
|
||||
})
|
|
@ -1,48 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
compressedCertsCache *lru.Cache
|
||||
)
|
||||
|
||||
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
// Hash all inputs
|
||||
hasher := fnv.New64a()
|
||||
for _, v := range chain {
|
||||
hasher.Write(v)
|
||||
}
|
||||
hasher.Write(pCommonSetHashes)
|
||||
hasher.Write(pCachedHashes)
|
||||
hash := hasher.Sum64()
|
||||
|
||||
var result []byte
|
||||
|
||||
resultI, isCached := compressedCertsCache.Get(hash)
|
||||
if isCached {
|
||||
result = resultI.([]byte)
|
||||
} else {
|
||||
var err error
|
||||
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
compressedCertsCache.Add(hash, result)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
|
||||
}
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Certificate cache", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
compressedCertsCache, err = lru.New(2)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("gives a compressed cert", func() {
|
||||
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
||||
expected, err := compressChain(chain, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
compressed, err := getCompressedCert(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(compressed).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("gets the same result multiple times", func() {
|
||||
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
||||
compressed, err := getCompressedCert(chain, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
compressed2, err := getCompressedCert(chain, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(compressed).To(Equal(compressed2))
|
||||
})
|
||||
|
||||
It("stores cached values", func() {
|
||||
chain := [][]byte{{0xde, 0xca, 0xfb, 0xad}}
|
||||
_, err := getCompressedCert(chain, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(compressedCertsCache.Len()).To(Equal(1))
|
||||
Expect(compressedCertsCache.Contains(uint64(3838929964809501833))).To(BeTrue())
|
||||
})
|
||||
|
||||
It("evicts old values", func() {
|
||||
_, err := getCompressedCert([][]byte{{0x00}}, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = getCompressedCert([][]byte{{0x01}}, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_, err = getCompressedCert([][]byte{{0x02}}, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(compressedCertsCache.Len()).To(Equal(2))
|
||||
})
|
||||
})
|
|
@ -1,113 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A CertChain holds a certificate and a private key
|
||||
type CertChain interface {
|
||||
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
|
||||
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
|
||||
GetLeafCert(sni string) ([]byte, error)
|
||||
}
|
||||
|
||||
// proofSource stores a key and a certificate for the server proof
|
||||
type certChain struct {
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
var _ CertChain = &certChain{}
|
||||
|
||||
var errNoMatchingCertificate = errors.New("no matching certificate found")
|
||||
|
||||
// NewCertChain loads the key and cert from files
|
||||
func NewCertChain(tlsConfig *tls.Config) CertChain {
|
||||
return &certChain{config: tlsConfig}
|
||||
}
|
||||
|
||||
// SignServerProof signs CHLO and server config for use in the server proof
|
||||
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
||||
cert, err := c.getCertForSNI(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signServerProof(cert, chlo, serverConfigData)
|
||||
}
|
||||
|
||||
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
|
||||
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
cert, err := c.getCertForSNI(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes)
|
||||
}
|
||||
|
||||
// GetLeafCert gets the leaf certificate
|
||||
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
|
||||
cert, err := c.getCertForSNI(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cert.Certificate[0], nil
|
||||
}
|
||||
|
||||
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
|
||||
conf := c.config
|
||||
conf, err := maybeGetConfigForClient(conf, sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The rest of this function is mostly copied from crypto/tls.getCertificate
|
||||
|
||||
if conf.GetCertificate != nil {
|
||||
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
|
||||
if cert != nil || err != nil {
|
||||
return cert, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(conf.Certificates) == 0 {
|
||||
return nil, errNoMatchingCertificate
|
||||
}
|
||||
|
||||
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
|
||||
// There's only one choice, so no point doing any work.
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
name := strings.ToLower(sni)
|
||||
for len(name) > 0 && name[len(name)-1] == '.' {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
|
||||
if cert, ok := conf.NameToCertificate[name]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a
|
||||
// match.
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if cert, ok := conf.NameToCertificate[candidate]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If nothing matches, return the first certificate.
|
||||
return &conf.Certificates[0], nil
|
||||
}
|
||||
|
||||
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
|
||||
if c.GetConfigForClient == nil {
|
||||
return c, nil
|
||||
}
|
||||
return c.GetConfigForClient(&tls.ClientHelloInfo{
|
||||
ServerName: sni,
|
||||
})
|
||||
}
|
|
@ -1,148 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"crypto/tls"
|
||||
"reflect"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Proof", func() {
|
||||
var (
|
||||
cc *certChain
|
||||
config *tls.Config
|
||||
cert tls.Certificate
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
cert = testdata.GetCertificate()
|
||||
config = &tls.Config{}
|
||||
cc = NewCertChain(config).(*certChain)
|
||||
})
|
||||
|
||||
Context("certificate compression", func() {
|
||||
It("compresses certs", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert)
|
||||
z.Close()
|
||||
kd := &certChain{
|
||||
config: &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
{Certificate: [][]byte{cert}},
|
||||
},
|
||||
},
|
||||
}
|
||||
certCompressed, err := kd.GetCertsCompressed("", nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(certCompressed).To(Equal(append([]byte{
|
||||
0x01, 0x00,
|
||||
0x08, 0x00, 0x00, 0x00,
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("errors when it can't retrieve a certificate", func() {
|
||||
_, err := cc.GetCertsCompressed("invalid domain", nil, nil)
|
||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
||||
})
|
||||
})
|
||||
|
||||
Context("signing server configs", func() {
|
||||
It("errors when it can't retrieve a certificate for the requested SNI", func() {
|
||||
_, err := cc.SignServerProof("invalid", []byte("chlo"), []byte("scfg"))
|
||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
||||
})
|
||||
|
||||
It("signs the server config", func() {
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
proof, err := cc.SignServerProof("", []byte("chlo"), []byte("scfg"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(proof).ToNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("retrieving certificates", func() {
|
||||
It("errors without certificates", func() {
|
||||
_, err := cc.getCertForSNI("")
|
||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
||||
})
|
||||
|
||||
It("uses first certificate in config.Certificates", func() {
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
cert, err := cc.getCertForSNI("")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("uses NameToCertificate entries", func() {
|
||||
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
|
||||
config.NameToCertificate = map[string]*tls.Certificate{
|
||||
"quic.clemente.io": &cert,
|
||||
}
|
||||
cert, err := cc.getCertForSNI("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("uses NameToCertificate entries with wildcard", func() {
|
||||
config.Certificates = []tls.Certificate{cert, cert} // two entries so the long path is used
|
||||
config.NameToCertificate = map[string]*tls.Certificate{
|
||||
"*.clemente.io": &cert,
|
||||
}
|
||||
cert, err := cc.getCertForSNI("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("uses GetCertificate", func() {
|
||||
config.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
Expect(clientHello.ServerName).To(Equal("quic.clemente.io"))
|
||||
return &cert, nil
|
||||
}
|
||||
cert, err := cc.getCertForSNI("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cert.PrivateKey).ToNot(BeNil())
|
||||
Expect(cert.Certificate[0]).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("gets leaf certificates", func() {
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
cert2, err := cc.GetLeafCert("")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cert2).To(Equal(cert.Certificate[0]))
|
||||
})
|
||||
|
||||
It("errors when it can't retrieve a leaf certificate", func() {
|
||||
_, err := cc.GetLeafCert("invalid domain")
|
||||
Expect(err).To(MatchError(errNoMatchingCertificate))
|
||||
})
|
||||
|
||||
It("respects GetConfigForClient", func() {
|
||||
if !reflect.ValueOf(tls.Config{}).FieldByName("GetConfigForClient").IsValid() {
|
||||
// Pre 1.8, we don't have to do anything
|
||||
return
|
||||
}
|
||||
nestedConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
l := func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
Expect(chi.ServerName).To(Equal("quic.clemente.io"))
|
||||
return nestedConfig, nil
|
||||
}
|
||||
reflect.ValueOf(config).Elem().FieldByName("GetConfigForClient").Set(reflect.ValueOf(l))
|
||||
resultCert, err := cc.getCertForSNI("quic.clemente.io")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(*resultCert).To(Equal(cert))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,272 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type entryType uint8
|
||||
|
||||
const (
|
||||
entryCompressed entryType = 1
|
||||
entryCached entryType = 2
|
||||
entryCommon entryType = 3
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
t entryType
|
||||
h uint64 // set hash
|
||||
i uint32 // index
|
||||
}
|
||||
|
||||
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
|
||||
res := &bytes.Buffer{}
|
||||
|
||||
cachedHashes, err := splitHashes(pCachedHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setHashes, err := splitHashes(pCommonSetHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chainHashes := make([]uint64, len(chain))
|
||||
for i := range chain {
|
||||
chainHashes[i] = HashCert(chain[i])
|
||||
}
|
||||
|
||||
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
|
||||
|
||||
totalUncompressedLen := 0
|
||||
for i, e := range entries {
|
||||
res.WriteByte(uint8(e.t))
|
||||
switch e.t {
|
||||
case entryCached:
|
||||
utils.LittleEndian.WriteUint64(res, e.h)
|
||||
case entryCommon:
|
||||
utils.LittleEndian.WriteUint64(res, e.h)
|
||||
utils.LittleEndian.WriteUint32(res, e.i)
|
||||
case entryCompressed:
|
||||
totalUncompressedLen += 4 + len(chain[i])
|
||||
}
|
||||
}
|
||||
res.WriteByte(0) // end of list
|
||||
|
||||
if totalUncompressedLen > 0 {
|
||||
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
|
||||
}
|
||||
|
||||
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
|
||||
|
||||
for i, e := range entries {
|
||||
if e.t != entryCompressed {
|
||||
continue
|
||||
}
|
||||
lenCert := len(chain[i])
|
||||
gz.Write([]byte{
|
||||
byte(lenCert & 0xff),
|
||||
byte((lenCert >> 8) & 0xff),
|
||||
byte((lenCert >> 16) & 0xff),
|
||||
byte((lenCert >> 24) & 0xff),
|
||||
})
|
||||
gz.Write(chain[i])
|
||||
}
|
||||
|
||||
gz.Close()
|
||||
}
|
||||
|
||||
return res.Bytes(), nil
|
||||
}
|
||||
|
||||
func decompressChain(data []byte) ([][]byte, error) {
|
||||
var chain [][]byte
|
||||
var entries []entry
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
var numCerts int
|
||||
var hasCompressedCerts bool
|
||||
for {
|
||||
entryTypeByte, err := r.ReadByte()
|
||||
if entryTypeByte == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
et := entryType(entryTypeByte)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numCerts++
|
||||
|
||||
switch et {
|
||||
case entryCached:
|
||||
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
|
||||
return nil, errors.New("unexpected cached certificate")
|
||||
case entryCommon:
|
||||
e := entry{t: entryCommon}
|
||||
e.h, err = utils.LittleEndian.ReadUint64(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.i, err = utils.LittleEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certSet, ok := certSets[e.h]
|
||||
if !ok {
|
||||
return nil, errors.New("unknown certSet")
|
||||
}
|
||||
if e.i >= uint32(len(certSet)) {
|
||||
return nil, errors.New("certificate not found in certSet")
|
||||
}
|
||||
entries = append(entries, e)
|
||||
chain = append(chain, certSet[e.i])
|
||||
case entryCompressed:
|
||||
hasCompressedCerts = true
|
||||
entries = append(entries, entry{t: entryCompressed})
|
||||
chain = append(chain, nil)
|
||||
default:
|
||||
return nil, errors.New("unknown entryType")
|
||||
}
|
||||
}
|
||||
|
||||
if numCerts == 0 {
|
||||
return make([][]byte, 0), nil
|
||||
}
|
||||
|
||||
if hasCompressedCerts {
|
||||
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
|
||||
if err != nil {
|
||||
fmt.Println(4)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zlibDict := buildZlibDictForEntries(entries, chain)
|
||||
gz, err := zlib.NewReaderDict(r, zlibDict)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
var totalLength uint32
|
||||
var certIndex int
|
||||
for totalLength < uncompressedLength {
|
||||
lenBytes := make([]byte, 4)
|
||||
_, err := gz.Read(lenBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certLen := binary.LittleEndian.Uint32(lenBytes)
|
||||
|
||||
cert := make([]byte, certLen)
|
||||
n, err := gz.Read(cert)
|
||||
if uint32(n) != certLen && err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if certIndex >= len(entries) {
|
||||
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
|
||||
}
|
||||
if entries[certIndex].t == entryCompressed {
|
||||
chain[certIndex] = cert
|
||||
certIndex++
|
||||
break
|
||||
}
|
||||
certIndex++
|
||||
}
|
||||
|
||||
totalLength += 4 + certLen
|
||||
}
|
||||
}
|
||||
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
|
||||
res := make([]entry, len(chain))
|
||||
chainLoop:
|
||||
for i := range chain {
|
||||
// Check if hash is in cachedHashes
|
||||
for j := range cachedHashes {
|
||||
if chainHashes[i] == cachedHashes[j] {
|
||||
res[i] = entry{t: entryCached, h: chainHashes[i]}
|
||||
continue chainLoop
|
||||
}
|
||||
}
|
||||
|
||||
// Go through common sets and check if it's in there
|
||||
for _, setHash := range setHashes {
|
||||
set, ok := certSets[setHash]
|
||||
if !ok {
|
||||
// We don't have this set
|
||||
continue
|
||||
}
|
||||
// We have this set, check if chain[i] is in the set
|
||||
pos := set.findCertInSet(chain[i])
|
||||
if pos >= 0 {
|
||||
// Found
|
||||
res[i] = entry{t: entryCommon, h: setHash, i: uint32(pos)}
|
||||
continue chainLoop
|
||||
}
|
||||
}
|
||||
|
||||
res[i] = entry{t: entryCompressed}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
|
||||
var dict bytes.Buffer
|
||||
|
||||
// First the cached and common in reverse order
|
||||
for i := len(entries) - 1; i >= 0; i-- {
|
||||
if entries[i].t == entryCompressed {
|
||||
continue
|
||||
}
|
||||
dict.Write(chain[i])
|
||||
}
|
||||
|
||||
dict.Write(certDictZlib)
|
||||
return dict.Bytes()
|
||||
}
|
||||
|
||||
func splitHashes(hashes []byte) ([]uint64, error) {
|
||||
if len(hashes)%8 != 0 {
|
||||
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
|
||||
}
|
||||
n := len(hashes) / 8
|
||||
res := make([]uint64, n)
|
||||
for i := 0; i < n; i++ {
|
||||
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func getCommonCertificateHashes() []byte {
|
||||
ccs := make([]byte, 8*len(certSets))
|
||||
i := 0
|
||||
for certSetHash := range certSets {
|
||||
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
|
||||
i++
|
||||
}
|
||||
return ccs
|
||||
}
|
||||
|
||||
// HashCert calculates the FNV1a hash of a certificate
|
||||
func HashCert(cert []byte) uint64 {
|
||||
h := fnv.New64a()
|
||||
h.Write(cert)
|
||||
return h.Sum64()
|
||||
}
|
|
@ -1,294 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/quic-go-certificates"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func byteHash(d []byte) []byte {
|
||||
h := fnv.New64a()
|
||||
h.Write(d)
|
||||
s := h.Sum64()
|
||||
res := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(res, s)
|
||||
return res
|
||||
}
|
||||
|
||||
var _ = Describe("Cert compression and decompression", func() {
|
||||
var certSetsOld map[uint64]certSet
|
||||
|
||||
BeforeEach(func() {
|
||||
certSetsOld = make(map[uint64]certSet)
|
||||
for s := range certSets {
|
||||
certSetsOld[s] = certSets[s]
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
certSets = certSetsOld
|
||||
})
|
||||
|
||||
It("compresses empty", func() {
|
||||
compressed, err := compressChain(nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(compressed).To(Equal([]byte{0}))
|
||||
})
|
||||
|
||||
It("decompresses empty", func() {
|
||||
compressed, err := compressChain(nil, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
uncompressed, err := decompressChain(compressed)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(uncompressed).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("gives correct single cert", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert)
|
||||
z.Close()
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(compressed).To(Equal(append([]byte{
|
||||
0x01, 0x00,
|
||||
0x08, 0x00, 0x00, 0x00,
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("decompresses a single cert", func() {
|
||||
cert := []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
uncompressed, err := decompressChain(compressed)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(uncompressed).To(Equal(chain))
|
||||
})
|
||||
|
||||
It("gives correct cert and intermediate", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert1)
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert2)
|
||||
z.Close()
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(compressed).To(Equal(append([]byte{
|
||||
0x01, 0x01, 0x00,
|
||||
0x10, 0x00, 0x00, 0x00,
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("decompresses the chain with a cert and an intermediate", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
decompressed, err := decompressChain(compressed)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decompressed).To(Equal(chain))
|
||||
})
|
||||
|
||||
It("uses cached certificates", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
certHash := byteHash(cert)
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, nil, certHash)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := append([]byte{0x02}, certHash...)
|
||||
expected = append(expected, 0x00)
|
||||
Expect(compressed).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("uses cached certificates and compressed combined", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
cert2Hash := byteHash(cert2)
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert1)
|
||||
z.Close()
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, nil, cert2Hash)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0x01, 0x02}
|
||||
expected = append(expected, cert2Hash...)
|
||||
expected = append(expected, 0x00)
|
||||
expected = append(expected, []byte{0x08, 0, 0, 0}...)
|
||||
expected = append(expected, certZlib.Bytes()...)
|
||||
Expect(compressed).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("uses common certificate sets", func() {
|
||||
cert := certsets.CertSet3[42]
|
||||
setHash := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0x03}
|
||||
expected = append(expected, setHash...)
|
||||
expected = append(expected, []byte{42, 0, 0, 0}...)
|
||||
expected = append(expected, 0x00)
|
||||
Expect(compressed).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("decompresses a single cert form a common certificate set", func() {
|
||||
cert := certsets.CertSet3[42]
|
||||
setHash := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
decompressed, err := decompressChain(compressed)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decompressed).To(Equal(chain))
|
||||
})
|
||||
|
||||
It("decompresses multiple certs form common certificate sets", func() {
|
||||
cert1 := certsets.CertSet3[42]
|
||||
cert2 := certsets.CertSet2[24]
|
||||
setHash := make([]byte, 16)
|
||||
binary.LittleEndian.PutUint64(setHash[0:8], certsets.CertSet3Hash)
|
||||
binary.LittleEndian.PutUint64(setHash[8:16], certsets.CertSet2Hash)
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
decompressed, err := decompressChain(compressed)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decompressed).To(Equal(chain))
|
||||
})
|
||||
|
||||
It("ignores uncommon certificate sets", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
setHash := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(setHash, 0xdeadbeef)
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, certDictZlib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert)
|
||||
z.Close()
|
||||
Expect(compressed).To(Equal(append([]byte{
|
||||
0x01, 0x00,
|
||||
0x08, 0x00, 0x00, 0x00,
|
||||
}, certZlib.Bytes()...)))
|
||||
})
|
||||
|
||||
It("errors if a common set does not exist", func() {
|
||||
cert := certsets.CertSet3[42]
|
||||
setHash := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
delete(certSets, certsets.CertSet3Hash)
|
||||
_, err = decompressChain(compressed)
|
||||
Expect(err).To(MatchError(errors.New("unknown certSet")))
|
||||
})
|
||||
|
||||
It("errors if a cert in a common set does not exist", func() {
|
||||
certSet := [][]byte{
|
||||
{0x1, 0x2, 0x3, 0x4},
|
||||
{0x5, 0x6, 0x7, 0x8},
|
||||
}
|
||||
certSets[0x1337] = certSet
|
||||
cert := certSet[1]
|
||||
setHash := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(setHash, 0x1337)
|
||||
chain := [][]byte{cert}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
certSets[0x1337] = certSet[:1] // delete the last certificate from the certSet
|
||||
_, err = decompressChain(compressed)
|
||||
Expect(err).To(MatchError(errors.New("certificate not found in certSet")))
|
||||
})
|
||||
|
||||
It("uses common certificates and compressed combined", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := certsets.CertSet3[42]
|
||||
setHash := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
||||
certZlib := &bytes.Buffer{}
|
||||
z, err := zlib.NewWriterLevelDict(certZlib, flate.BestCompression, append(cert2, certDictZlib...))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
z.Write([]byte{0x04, 0x00, 0x00, 0x00})
|
||||
z.Write(cert1)
|
||||
z.Close()
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
expected := []byte{0x01, 0x03}
|
||||
expected = append(expected, setHash...)
|
||||
expected = append(expected, []byte{42, 0, 0, 0}...)
|
||||
expected = append(expected, 0x00)
|
||||
expected = append(expected, []byte{0x08, 0, 0, 0}...)
|
||||
expected = append(expected, certZlib.Bytes()...)
|
||||
Expect(compressed).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("decompresses a certficate from a common set and a compressed cert combined", func() {
|
||||
cert1 := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
cert2 := certsets.CertSet3[42]
|
||||
setHash := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(setHash, certsets.CertSet3Hash)
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, setHash, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
decompressed, err := decompressChain(compressed)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(decompressed).To(Equal(chain))
|
||||
})
|
||||
|
||||
It("rejects invalid CCS / CCRT hashes", func() {
|
||||
cert := []byte{0xde, 0xca, 0xfb, 0xad}
|
||||
chain := [][]byte{cert}
|
||||
_, err := compressChain(chain, []byte("foo"), nil)
|
||||
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
|
||||
_, err = compressChain(chain, nil, []byte("foo"))
|
||||
Expect(err).To(MatchError("expected a multiple of 8 bytes for CCS / CCRT hashes"))
|
||||
})
|
||||
|
||||
Context("common certificate hashes", func() {
|
||||
It("gets the hashes", func() {
|
||||
ccs := getCommonCertificateHashes()
|
||||
Expect(ccs).ToNot(BeEmpty())
|
||||
hashes, err := splitHashes(ccs)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for _, hash := range hashes {
|
||||
Expect(certSets).To(HaveKey(hash))
|
||||
}
|
||||
})
|
||||
|
||||
It("returns an empty slice if there are not common sets", func() {
|
||||
certSets = make(map[uint64]certSet)
|
||||
ccs := getCommonCertificateHashes()
|
||||
Expect(ccs).ToNot(BeNil())
|
||||
Expect(ccs).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,128 +0,0 @@
|
|||
package crypto
|
||||
|
||||
var certDictZlib = []byte{
|
||||
0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04,
|
||||
0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03,
|
||||
0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30,
|
||||
0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01,
|
||||
0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07,
|
||||
0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65,
|
||||
0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34,
|
||||
0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31,
|
||||
0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72,
|
||||
0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e,
|
||||
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65,
|
||||
0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
|
||||
0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63,
|
||||
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73,
|
||||
0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06,
|
||||
0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d,
|
||||
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
|
||||
0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55,
|
||||
0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86,
|
||||
0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01,
|
||||
0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2,
|
||||
0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e,
|
||||
0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
|
||||
0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69,
|
||||
0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03,
|
||||
0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09,
|
||||
0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30,
|
||||
0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
|
||||
0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08,
|
||||
0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30,
|
||||
0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74,
|
||||
0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03,
|
||||
0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33,
|
||||
0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74,
|
||||
0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63,
|
||||
0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74,
|
||||
0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03,
|
||||
0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53,
|
||||
0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66,
|
||||
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68,
|
||||
0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55,
|
||||
0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37,
|
||||
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d,
|
||||
0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c,
|
||||
0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00,
|
||||
0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff,
|
||||
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55,
|
||||
0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05,
|
||||
0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
|
||||
0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff,
|
||||
0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d,
|
||||
0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86,
|
||||
0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e,
|
||||
0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
|
||||
0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08,
|
||||
0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74,
|
||||
0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65,
|
||||
0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
|
||||
0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17,
|
||||
0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01,
|
||||
0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a,
|
||||
0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01,
|
||||
0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04,
|
||||
0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86,
|
||||
0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
|
||||
0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17,
|
||||
0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72,
|
||||
0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31,
|
||||
0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65,
|
||||
0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74,
|
||||
0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39,
|
||||
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73,
|
||||
0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68,
|
||||
0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76,
|
||||
0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
|
||||
0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e,
|
||||
0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11,
|
||||
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11,
|
||||
0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01,
|
||||
0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04,
|
||||
0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13,
|
||||
0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50,
|
||||
0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e,
|
||||
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30,
|
||||
0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61,
|
||||
0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56,
|
||||
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31,
|
||||
0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65,
|
||||
0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63,
|
||||
0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53,
|
||||
0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41,
|
||||
0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e,
|
||||
0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72,
|
||||
0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63,
|
||||
0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a,
|
||||
0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a,
|
||||
0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72,
|
||||
0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72,
|
||||
0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e,
|
||||
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63,
|
||||
0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63,
|
||||
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72,
|
||||
0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05,
|
||||
0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b,
|
||||
0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74,
|
||||
0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64,
|
||||
0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06,
|
||||
0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68,
|
||||
0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66,
|
||||
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64,
|
||||
0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73,
|
||||
0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74,
|
||||
0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72,
|
||||
0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16,
|
||||
0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee,
|
||||
0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27,
|
||||
0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30,
|
||||
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73,
|
||||
}
|
|
@ -1,135 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"hash/fnv"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// CertManager manages the certificates sent by the server
|
||||
type CertManager interface {
|
||||
SetData([]byte) error
|
||||
GetCommonCertificateHashes() []byte
|
||||
GetLeafCert() []byte
|
||||
GetLeafCertHash() (uint64, error)
|
||||
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
|
||||
Verify(hostname string) error
|
||||
GetChain() []*x509.Certificate
|
||||
}
|
||||
|
||||
type certManager struct {
|
||||
chain []*x509.Certificate
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
var _ CertManager = &certManager{}
|
||||
|
||||
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
|
||||
|
||||
// NewCertManager creates a new CertManager
|
||||
func NewCertManager(tlsConfig *tls.Config) CertManager {
|
||||
return &certManager{config: tlsConfig}
|
||||
}
|
||||
|
||||
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
|
||||
func (c *certManager) SetData(data []byte) error {
|
||||
byteChain, err := decompressChain(data)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
|
||||
}
|
||||
|
||||
chain := make([]*x509.Certificate, len(byteChain))
|
||||
for i, data := range byteChain {
|
||||
cert, err := x509.ParseCertificate(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chain[i] = cert
|
||||
}
|
||||
|
||||
c.chain = chain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *certManager) GetChain() []*x509.Certificate {
|
||||
return c.chain
|
||||
}
|
||||
|
||||
func (c *certManager) GetCommonCertificateHashes() []byte {
|
||||
return getCommonCertificateHashes()
|
||||
}
|
||||
|
||||
// GetLeafCert returns the leaf certificate of the certificate chain
|
||||
// it returns nil if the certificate chain has not yet been set
|
||||
func (c *certManager) GetLeafCert() []byte {
|
||||
if len(c.chain) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.chain[0].Raw
|
||||
}
|
||||
|
||||
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
|
||||
func (c *certManager) GetLeafCertHash() (uint64, error) {
|
||||
leafCert := c.GetLeafCert()
|
||||
if leafCert == nil {
|
||||
return 0, errNoCertificateChain
|
||||
}
|
||||
|
||||
h := fnv.New64a()
|
||||
_, err := h.Write(leafCert)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return h.Sum64(), nil
|
||||
}
|
||||
|
||||
// VerifyServerProof verifies the signature of the server config
|
||||
// it should only be called after the certificate chain has been set, otherwise it returns false
|
||||
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
|
||||
if len(c.chain) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
|
||||
}
|
||||
|
||||
// Verify verifies the certificate chain
|
||||
func (c *certManager) Verify(hostname string) error {
|
||||
if len(c.chain) == 0 {
|
||||
return errNoCertificateChain
|
||||
}
|
||||
|
||||
if c.config != nil && c.config.InsecureSkipVerify {
|
||||
return nil
|
||||
}
|
||||
|
||||
leafCert := c.chain[0]
|
||||
|
||||
var opts x509.VerifyOptions
|
||||
if c.config != nil {
|
||||
opts.Roots = c.config.RootCAs
|
||||
if c.config.Time == nil {
|
||||
opts.CurrentTime = time.Now()
|
||||
} else {
|
||||
opts.CurrentTime = c.config.Time()
|
||||
}
|
||||
}
|
||||
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
|
||||
opts.DNSName = hostname
|
||||
|
||||
// the first certificate is the leaf certificate, all others are intermediates
|
||||
if len(c.chain) > 1 {
|
||||
intermediates := x509.NewCertPool()
|
||||
for i := 1; i < len(c.chain); i++ {
|
||||
intermediates.AddCert(c.chain[i])
|
||||
}
|
||||
opts.Intermediates = intermediates
|
||||
}
|
||||
|
||||
_, err := leafCert.Verify(opts)
|
||||
return err
|
||||
}
|
|
@ -1,348 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"math/big"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Cert Manager", func() {
|
||||
var cm *certManager
|
||||
var key1, key2 *rsa.PrivateKey
|
||||
var cert1, cert2 []byte
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
cm = NewCertManager(nil).(*certManager)
|
||||
key1, err = rsa.GenerateKey(rand.Reader, 768)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
key2, err = rsa.GenerateKey(rand.Reader, 768)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
template := &x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
cert1, err = x509.CreateCertificate(rand.Reader, template, template, &key1.PublicKey, key1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cert2, err = x509.CreateCertificate(rand.Reader, template, template, &key2.PublicKey, key2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("saves a client TLS config", func() {
|
||||
tlsConf := &tls.Config{ServerName: "quic.clemente.io"}
|
||||
cm = NewCertManager(tlsConf).(*certManager)
|
||||
Expect(cm.config.ServerName).To(Equal("quic.clemente.io"))
|
||||
})
|
||||
|
||||
It("errors when given invalid data", func() {
|
||||
err := cm.SetData([]byte("foobar"))
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
|
||||
})
|
||||
|
||||
It("gets the common certificate hashes", func() {
|
||||
ccs := cm.GetCommonCertificateHashes()
|
||||
Expect(ccs).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
Context("setting the data", func() {
|
||||
It("decompresses a certificate chain", func() {
|
||||
chain := [][]byte{cert1, cert2}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cm.SetData(compressed)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cm.chain[0].Raw).To(Equal(cert1))
|
||||
Expect(cm.chain[1].Raw).To(Equal(cert2))
|
||||
})
|
||||
|
||||
It("errors if it can't decompress the chain", func() {
|
||||
err := cm.SetData([]byte("invalid data"))
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")))
|
||||
})
|
||||
|
||||
It("errors if it can't parse a certificate", func() {
|
||||
chain := [][]byte{[]byte("cert1"), []byte("cert2")}
|
||||
compressed, err := compressChain(chain, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cm.SetData(compressed)
|
||||
_, ok := err.(asn1.StructuralError)
|
||||
Expect(ok).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("getting the leaf cert", func() {
|
||||
It("gets it", func() {
|
||||
xcert1, err := x509.ParseCertificate(cert1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
xcert2, err := x509.ParseCertificate(cert2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cm.chain = []*x509.Certificate{xcert1, xcert2}
|
||||
leafCert := cm.GetLeafCert()
|
||||
Expect(leafCert).To(Equal(cert1))
|
||||
})
|
||||
|
||||
It("returns nil if the chain hasn't been set yet", func() {
|
||||
leafCert := cm.GetLeafCert()
|
||||
Expect(leafCert).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("getting the leaf cert hash", func() {
|
||||
It("calculates the FVN1a 64 hash", func() {
|
||||
cm.chain = make([]*x509.Certificate, 1)
|
||||
cm.chain[0] = &x509.Certificate{
|
||||
Raw: []byte("test fnv hash"),
|
||||
}
|
||||
hash, err := cm.GetLeafCertHash()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// hash calculated on http://www.nitrxgen.net/hashgen/
|
||||
Expect(hash).To(Equal(uint64(0x4770f6141fa0f5ad)))
|
||||
})
|
||||
|
||||
It("errors if the certificate chain is not loaded", func() {
|
||||
_, err := cm.GetLeafCertHash()
|
||||
Expect(err).To(MatchError(errNoCertificateChain))
|
||||
})
|
||||
})
|
||||
|
||||
Context("verifying the server config signature", func() {
|
||||
It("returns false when the chain hasn't been set yet", func() {
|
||||
valid := cm.VerifyServerProof([]byte("proof"), []byte("chlo"), []byte("scfg"))
|
||||
Expect(valid).To(BeFalse())
|
||||
})
|
||||
|
||||
It("verifies the signature", func() {
|
||||
chlo := []byte("client hello")
|
||||
scfg := []byte("server config data")
|
||||
xcert1, err := x509.ParseCertificate(cert1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cm.chain = []*x509.Certificate{xcert1}
|
||||
proof, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
valid := cm.VerifyServerProof(proof, chlo, scfg)
|
||||
Expect(valid).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects an invalid signature", func() {
|
||||
xcert1, err := x509.ParseCertificate(cert1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cm.chain = []*x509.Certificate{xcert1}
|
||||
valid := cm.VerifyServerProof([]byte("invalid proof"), []byte("chlo"), []byte("scfg"))
|
||||
Expect(valid).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("verifying the certificate chain", func() {
|
||||
generateCertificate := func(template, parent *x509.Certificate, pubKey *rsa.PublicKey, privKey *rsa.PrivateKey) *x509.Certificate {
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pubKey, privKey)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return cert
|
||||
}
|
||||
|
||||
getCertificate := func(template *x509.Certificate) (*rsa.PrivateKey, *x509.Certificate) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return key, generateCertificate(template, template, &key.PublicKey, key)
|
||||
}
|
||||
|
||||
It("accepts a valid certificate", func() {
|
||||
cc := NewCertChain(testdata.GetTLSConfig()).(*certChain)
|
||||
tlsCert, err := cc.getCertForSNI("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
for _, data := range tlsCert.Certificate {
|
||||
var cert *x509.Certificate
|
||||
cert, err = x509.ParseCertificate(data)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cm.chain = append(cm.chain, cert)
|
||||
}
|
||||
err = cm.Verify("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("doesn't accept an expired certificate", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
||||
Skip("windows")
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-25 * time.Hour),
|
||||
NotAfter: time.Now().Add(-time.Hour),
|
||||
}
|
||||
_, leafCert := getCertificate(template)
|
||||
|
||||
cm.chain = []*x509.Certificate{leafCert}
|
||||
err := cm.Verify("")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
|
||||
})
|
||||
|
||||
It("doesn't accept a certificate that is not yet valid", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
||||
Skip("windows")
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(time.Hour),
|
||||
NotAfter: time.Now().Add(25 * time.Hour),
|
||||
}
|
||||
_, leafCert := getCertificate(template)
|
||||
|
||||
cm.chain = []*x509.Certificate{leafCert}
|
||||
err := cm.Verify("")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
|
||||
})
|
||||
|
||||
It("doesn't accept an certificate for the wrong hostname", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
||||
Skip("windows")
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
Subject: pkix.Name{CommonName: "google.com"},
|
||||
}
|
||||
_, leafCert := getCertificate(template)
|
||||
|
||||
cm.chain = []*x509.Certificate{leafCert}
|
||||
err := cm.Verify("quic.clemente.io")
|
||||
Expect(err).To(HaveOccurred())
|
||||
_, ok := err.(x509.HostnameError)
|
||||
Expect(ok).To(BeTrue())
|
||||
})
|
||||
|
||||
It("errors if the chain hasn't been set yet", func() {
|
||||
err := cm.Verify("example.com")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
// this tests relies on LetsEncrypt not being contained in the Root CAs
|
||||
It("rejects valid certificate with missing certificate chain", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
Skip("LetsEncrypt Root CA is included in Windows")
|
||||
}
|
||||
|
||||
cert := testdata.GetCertificate()
|
||||
xcert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cm.chain = []*x509.Certificate{xcert}
|
||||
err = cm.Verify("quic.clemente.io")
|
||||
_, ok := err.(x509.UnknownAuthorityError)
|
||||
Expect(ok).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't do any certificate verification if InsecureSkipVerify is set", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
||||
Skip("windows")
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
}
|
||||
|
||||
_, leafCert := getCertificate(template)
|
||||
cm.config = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
cm.chain = []*x509.Certificate{leafCert}
|
||||
err := cm.Verify("quic.clemente.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("uses the time specified in a client TLS config", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
||||
Skip("windows")
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-25 * time.Hour),
|
||||
NotAfter: time.Now().Add(-23 * time.Hour),
|
||||
Subject: pkix.Name{CommonName: "quic.clemente.io"},
|
||||
}
|
||||
_, leafCert := getCertificate(template)
|
||||
cm.chain = []*x509.Certificate{leafCert}
|
||||
cm.config = &tls.Config{
|
||||
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
|
||||
}
|
||||
err := cm.Verify("quic.clemente.io")
|
||||
_, ok := err.(x509.UnknownAuthorityError)
|
||||
Expect(ok).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects certificates that are expired at the time specified in a client TLS config", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
||||
Skip("windows")
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
}
|
||||
_, leafCert := getCertificate(template)
|
||||
cm.chain = []*x509.Certificate{leafCert}
|
||||
cm.config = &tls.Config{
|
||||
Time: func() time.Time { return time.Now().Add(-24 * time.Hour) },
|
||||
}
|
||||
err := cm.Verify("quic.clemente.io")
|
||||
Expect(err.(x509.CertificateInvalidError).Reason).To(Equal(x509.Expired))
|
||||
})
|
||||
|
||||
It("uses the Root CA given in the client config", func() {
|
||||
if runtime.GOOS == "windows" {
|
||||
// certificate validation works different on windows, see https://golang.org/src/crypto/x509/verify.go line 238
|
||||
Skip("windows")
|
||||
}
|
||||
|
||||
templateRoot := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
rootKey, rootCert := getCertificate(templateRoot)
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
Subject: pkix.Name{CommonName: "google.com"},
|
||||
}
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
leafCert := generateCertificate(template, rootCert, &key.PublicKey, rootKey)
|
||||
|
||||
rootCAPool := x509.NewCertPool()
|
||||
rootCAPool.AddCert(rootCert)
|
||||
|
||||
cm.chain = []*x509.Certificate{leafCert}
|
||||
cm.config = &tls.Config{
|
||||
RootCAs: rootCAPool,
|
||||
}
|
||||
err = cm.Verify("google.com")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,24 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go-certificates"
|
||||
)
|
||||
|
||||
type certSet [][]byte
|
||||
|
||||
var certSets = map[uint64]certSet{
|
||||
certsets.CertSet2Hash: certsets.CertSet2,
|
||||
certsets.CertSet3Hash: certsets.CertSet3,
|
||||
}
|
||||
|
||||
// findCertInSet searches for the cert in the set. Negative return value means not found.
|
||||
func (s *certSet) findCertInSet(cert []byte) int {
|
||||
for i, c := range *s {
|
||||
if bytes.Equal(c, cert) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
// +build ignore
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/aead/chacha20"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type aeadChacha20Poly1305 struct {
|
||||
otherIV []byte
|
||||
myIV []byte
|
||||
encrypter cipher.AEAD
|
||||
decrypter cipher.AEAD
|
||||
}
|
||||
|
||||
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
|
||||
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
|
||||
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
|
||||
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
|
||||
}
|
||||
// copy because ChaCha20Poly1305 expects array pointers
|
||||
var MyKey, OtherKey [32]byte
|
||||
copy(MyKey[:], myKey)
|
||||
copy(OtherKey[:], otherKey)
|
||||
|
||||
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &aeadChacha20Poly1305{
|
||||
otherIV: otherIV,
|
||||
myIV: myIV,
|
||||
encrypter: encrypter,
|
||||
decrypter: decrypter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
|
||||
}
|
||||
|
||||
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
|
||||
res := make([]byte, 12)
|
||||
copy(res[0:4], iv)
|
||||
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
|
||||
return res
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
// +build ignore
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Chacha20poly1305", func() {
|
||||
var (
|
||||
alice, bob AEAD
|
||||
keyAlice, keyBob, ivAlice, ivBob []byte
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
keyAlice = make([]byte, 32)
|
||||
keyBob = make([]byte, 32)
|
||||
ivAlice = make([]byte, 4)
|
||||
ivBob = make([]byte, 4)
|
||||
rand.Reader.Read(keyAlice)
|
||||
rand.Reader.Read(keyBob)
|
||||
rand.Reader.Read(ivAlice)
|
||||
rand.Reader.Read(ivBob)
|
||||
var err error
|
||||
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("seals and opens", func() {
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
text, err := bob.Open(nil, b, 42, []byte("aad"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(text).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("seals and opens reverse", func() {
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
text, err := alice.Open(nil, b, 42, []byte("aad"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(text).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("has the proper length", func() {
|
||||
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
Expect(b).To(HaveLen(6 + 12))
|
||||
})
|
||||
|
||||
It("fails with wrong aad", func() {
|
||||
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
_, err := bob.Open(nil, b, 42, []byte("aad2"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects wrong key and iv sizes", func() {
|
||||
var err error
|
||||
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
|
||||
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
|
||||
Expect(err).To(MatchError(e))
|
||||
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
|
||||
Expect(err).To(MatchError(e))
|
||||
})
|
||||
})
|
|
@ -1,41 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
// KeyExchange manages the exchange of keys
|
||||
type curve25519KEX struct {
|
||||
secret [32]byte
|
||||
public [32]byte
|
||||
}
|
||||
|
||||
var _ KeyExchange = &curve25519KEX{}
|
||||
|
||||
// NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html
|
||||
func NewCurve25519KEX() (KeyExchange, error) {
|
||||
c := &curve25519KEX{}
|
||||
if _, err := rand.Read(c.secret[:]); err != nil {
|
||||
return nil, errors.New("Curve25519: could not create private key")
|
||||
}
|
||||
curve25519.ScalarBaseMult(&c.public, &c.secret)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *curve25519KEX) PublicKey() []byte {
|
||||
return c.public[:]
|
||||
}
|
||||
|
||||
func (c *curve25519KEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
|
||||
if len(otherPublic) != 32 {
|
||||
return nil, errors.New("Curve25519: expected public key of 32 byte")
|
||||
}
|
||||
var res [32]byte
|
||||
var otherPublicArray [32]byte
|
||||
copy(otherPublicArray[:], otherPublic)
|
||||
curve25519.ScalarMult(&res, &c.secret, &otherPublicArray)
|
||||
return res[:], nil
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ProofRsa", func() {
|
||||
It("works", func() {
|
||||
a, err := NewCurve25519KEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
b, err := NewCurve25519KEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sA, err := a.CalculateSharedKey(b.PublicKey())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sB, err := b.CalculateSharedKey(a.PublicKey())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sA).To(Equal(sB))
|
||||
})
|
||||
|
||||
It("rejects short public keys", func() {
|
||||
a, err := NewCurve25519KEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = a.CalculateSharedKey(nil)
|
||||
Expect(err).To(MatchError("Curve25519: expected public key of 32 byte"))
|
||||
})
|
||||
})
|
|
@ -1,100 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// DeriveKeysChacha20 derives the client and server keys and creates a matching chacha20poly1305 AEAD instance
|
||||
// func DeriveKeysChacha20(version protocol.VersionNumber, forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) {
|
||||
// otherKey, myKey, otherIV, myIV, err := deriveKeys(version, forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 32)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
|
||||
// }
|
||||
|
||||
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
|
||||
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
|
||||
var swap bool
|
||||
if pers == protocol.PerspectiveClient {
|
||||
swap = true
|
||||
}
|
||||
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
|
||||
}
|
||||
|
||||
// deriveKeys derives the keys and the IVs
|
||||
// swap should be set true if generating the values for the client, and false for the server
|
||||
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
|
||||
var info bytes.Buffer
|
||||
if forwardSecure {
|
||||
info.Write([]byte("QUIC forward secure key expansion\x00"))
|
||||
} else {
|
||||
info.Write([]byte("QUIC key expansion\x00"))
|
||||
}
|
||||
info.Write(connID)
|
||||
info.Write(chlo)
|
||||
info.Write(scfg)
|
||||
info.Write(cert)
|
||||
|
||||
r := hkdf.New(sha256.New, sharedSecret, nonces, info.Bytes())
|
||||
|
||||
s := make([]byte, 2*keyLen+2*4)
|
||||
if _, err := io.ReadFull(r, s); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
key1 := s[:keyLen]
|
||||
key2 := s[keyLen : 2*keyLen]
|
||||
iv1 := s[2*keyLen : 2*keyLen+4]
|
||||
iv2 := s[2*keyLen+4:]
|
||||
|
||||
var otherKey, myKey []byte
|
||||
var otherIV, myIV []byte
|
||||
|
||||
if !forwardSecure {
|
||||
if err := diversify(key2, iv2, divNonce); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if swap {
|
||||
otherKey = key2
|
||||
myKey = key1
|
||||
otherIV = iv2
|
||||
myIV = iv1
|
||||
} else {
|
||||
otherKey = key1
|
||||
myKey = key2
|
||||
otherIV = iv1
|
||||
myIV = iv2
|
||||
}
|
||||
|
||||
return otherKey, myKey, otherIV, myIV, nil
|
||||
}
|
||||
|
||||
func diversify(key, iv, divNonce []byte) error {
|
||||
secret := make([]byte, len(key)+len(iv))
|
||||
copy(secret, key)
|
||||
copy(secret[len(key):], iv)
|
||||
|
||||
r := hkdf.New(sha256.New, secret, divNonce, []byte("QUIC key diversification"))
|
||||
|
||||
if _, err := io.ReadFull(r, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.ReadFull(r, iv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,197 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("QUIC Crypto Key Derivation", func() {
|
||||
// Context("chacha20poly1305", func() {
|
||||
// It("derives non-fs keys", func() {
|
||||
// aead, err := DeriveKeysChacha20(
|
||||
// protocol.Version32,
|
||||
// false,
|
||||
// []byte("0123456789012345678901"),
|
||||
// []byte("nonce"),
|
||||
// protocol.ConnectionID(42),
|
||||
// []byte("chlo"),
|
||||
// []byte("scfg"),
|
||||
// []byte("cert"),
|
||||
// nil,
|
||||
// )
|
||||
// Expect(err).ToNot(HaveOccurred())
|
||||
// chacha := aead.(*aeadChacha20Poly1305)
|
||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
||||
// Expect(chacha.myIV).To(Equal([]byte{0xf0, 0xf5, 0x4c, 0xa8}))
|
||||
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
|
||||
// })
|
||||
//
|
||||
// It("derives fs keys", func() {
|
||||
// aead, err := DeriveKeysChacha20(
|
||||
// protocol.Version32,
|
||||
// true,
|
||||
// []byte("0123456789012345678901"),
|
||||
// []byte("nonce"),
|
||||
// protocol.ConnectionID(42),
|
||||
// []byte("chlo"),
|
||||
// []byte("scfg"),
|
||||
// []byte("cert"),
|
||||
// nil,
|
||||
// )
|
||||
// Expect(err).ToNot(HaveOccurred())
|
||||
// chacha := aead.(*aeadChacha20Poly1305)
|
||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
||||
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
|
||||
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
|
||||
// })
|
||||
//
|
||||
// It("does not use diversification nonces in FS key derivation", func() {
|
||||
// aead, err := DeriveKeysChacha20(
|
||||
// protocol.Version33,
|
||||
// true,
|
||||
// []byte("0123456789012345678901"),
|
||||
// []byte("nonce"),
|
||||
// protocol.ConnectionID(42),
|
||||
// []byte("chlo"),
|
||||
// []byte("scfg"),
|
||||
// []byte("cert"),
|
||||
// []byte("divnonce"),
|
||||
// )
|
||||
// Expect(err).ToNot(HaveOccurred())
|
||||
// chacha := aead.(*aeadChacha20Poly1305)
|
||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
||||
// Expect(chacha.myIV).To(Equal([]byte{0xf5, 0x73, 0x11, 0x79}))
|
||||
// Expect(chacha.otherIV).To(Equal([]byte{0xf7, 0x26, 0x4d, 0x2c}))
|
||||
// })
|
||||
//
|
||||
// It("uses diversification nonces in initial key derivation", func() {
|
||||
// aead, err := DeriveKeysChacha20(
|
||||
// protocol.Version33,
|
||||
// false,
|
||||
// []byte("0123456789012345678901"),
|
||||
// []byte("nonce"),
|
||||
// protocol.ConnectionID(42),
|
||||
// []byte("chlo"),
|
||||
// []byte("scfg"),
|
||||
// []byte("cert"),
|
||||
// []byte("divnonce"),
|
||||
// )
|
||||
// Expect(err).ToNot(HaveOccurred())
|
||||
// chacha := aead.(*aeadChacha20Poly1305)
|
||||
// // If the IVs match, the keys will match too, since the keys are read earlier
|
||||
// Expect(chacha.myIV).To(Equal([]byte{0xc4, 0x12, 0x25, 0x64}))
|
||||
// Expect(chacha.otherIV).To(Equal([]byte{0x75, 0xd8, 0xa2, 0x8d}))
|
||||
// })
|
||||
// })
|
||||
|
||||
Context("AES-GCM", func() {
|
||||
It("derives non-forward secure keys", func() {
|
||||
aead, err := DeriveQuicCryptoAESKeys(
|
||||
false,
|
||||
[]byte("0123456789012345678901"),
|
||||
[]byte("nonce"),
|
||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
||||
[]byte("chlo"),
|
||||
[]byte("scfg"),
|
||||
[]byte("cert"),
|
||||
[]byte("divnonce"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aesgcm := aead.(*aeadAESGCM12)
|
||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
||||
Expect(aesgcm.myIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
|
||||
Expect(aesgcm.otherIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
|
||||
})
|
||||
|
||||
It("uses the diversification nonce when generating non-forwared secure keys", func() {
|
||||
aead1, err := DeriveQuicCryptoAESKeys(
|
||||
false,
|
||||
[]byte("0123456789012345678901"),
|
||||
[]byte("nonce"),
|
||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
||||
[]byte("chlo"),
|
||||
[]byte("scfg"),
|
||||
[]byte("cert"),
|
||||
[]byte("divnonce"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aead2, err := DeriveQuicCryptoAESKeys(
|
||||
false,
|
||||
[]byte("0123456789012345678901"),
|
||||
[]byte("nonce"),
|
||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
||||
[]byte("chlo"),
|
||||
[]byte("scfg"),
|
||||
[]byte("cert"),
|
||||
[]byte("ecnonvid"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aesgcm1 := aead1.(*aeadAESGCM12)
|
||||
aesgcm2 := aead2.(*aeadAESGCM12)
|
||||
Expect(aesgcm1.myIV).ToNot(Equal(aesgcm2.myIV))
|
||||
Expect(aesgcm1.otherIV).To(Equal(aesgcm2.otherIV))
|
||||
})
|
||||
|
||||
It("derives non-forward secure keys, for the other side", func() {
|
||||
aead, err := DeriveQuicCryptoAESKeys(
|
||||
false,
|
||||
[]byte("0123456789012345678901"),
|
||||
[]byte("nonce"),
|
||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
||||
[]byte("chlo"),
|
||||
[]byte("scfg"),
|
||||
[]byte("cert"),
|
||||
[]byte("divnonce"),
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aesgcm := aead.(*aeadAESGCM12)
|
||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
||||
Expect(aesgcm.otherIV).To(Equal([]byte{0x1c, 0xec, 0xac, 0x9b}))
|
||||
Expect(aesgcm.myIV).To(Equal([]byte{0x64, 0xef, 0x3c, 0x9}))
|
||||
})
|
||||
|
||||
It("derives forward secure keys", func() {
|
||||
aead, err := DeriveQuicCryptoAESKeys(
|
||||
true,
|
||||
[]byte("0123456789012345678901"),
|
||||
[]byte("nonce"),
|
||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
||||
[]byte("chlo"),
|
||||
[]byte("scfg"),
|
||||
[]byte("cert"),
|
||||
nil,
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aesgcm := aead.(*aeadAESGCM12)
|
||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
||||
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
|
||||
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
|
||||
})
|
||||
|
||||
It("does not use div-nonce for FS key derivation", func() {
|
||||
aead, err := DeriveQuicCryptoAESKeys(
|
||||
true,
|
||||
[]byte("0123456789012345678901"),
|
||||
[]byte("nonce"),
|
||||
protocol.ConnectionID([]byte{42, 0, 0, 0, 0, 0, 0, 0}),
|
||||
[]byte("chlo"),
|
||||
[]byte("scfg"),
|
||||
[]byte("cert"),
|
||||
[]byte("divnonce"),
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
aesgcm := aead.(*aeadAESGCM12)
|
||||
// If the IVs match, the keys will match too, since the keys are read earlier
|
||||
Expect(aesgcm.myIV).To(Equal([]byte{0x7, 0xad, 0xab, 0xb8}))
|
||||
Expect(aesgcm.otherIV).To(Equal([]byte{0xf2, 0x7a, 0xcc, 0x42}))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,7 +0,0 @@
|
|||
package crypto
|
||||
|
||||
// KeyExchange manages the exchange of keys
|
||||
type KeyExchange interface {
|
||||
PublicKey() []byte
|
||||
CalculateSharedKey(otherPublic []byte) ([]byte, error)
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
|
||||
// NewNullAEAD creates a NullAEAD
|
||||
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
|
||||
if v.UsesTLS() {
|
||||
return newNullAEADAESGCM(connID, p)
|
||||
}
|
||||
return &nullAEADFNV128a{perspective: p}, nil
|
||||
}
|
|
@ -8,7 +8,8 @@ import (
|
|||
|
||||
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
|
||||
|
||||
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||
// NewNullAEAD creates a NullAEAD
|
||||
func NewNullAEAD(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
|
||||
clientSecret, serverSecret := computeSecrets(connectionID)
|
||||
|
||||
var mySecret, otherSecret []byte
|
||||
|
|
|
@ -56,9 +56,9 @@ var _ = Describe("NullAEAD using AES-GCM", func() {
|
|||
|
||||
It("seals and opens", func() {
|
||||
connectionID := protocol.ConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef})
|
||||
clientAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveClient)
|
||||
clientAEAD, err := NewNullAEAD(connectionID, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAEAD, err := newNullAEADAESGCM(connectionID, protocol.PerspectiveServer)
|
||||
serverAEAD, err := NewNullAEAD(connectionID, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
|
@ -74,9 +74,9 @@ var _ = Describe("NullAEAD using AES-GCM", func() {
|
|||
It("doesn't work if initialized with different connection IDs", func() {
|
||||
c1 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1})
|
||||
c2 := protocol.ConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2})
|
||||
clientAEAD, err := newNullAEADAESGCM(c1, protocol.PerspectiveClient)
|
||||
clientAEAD, err := NewNullAEAD(c1, protocol.PerspectiveClient)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverAEAD, err := newNullAEADAESGCM(c2, protocol.PerspectiveServer)
|
||||
serverAEAD, err := NewNullAEAD(c2, protocol.PerspectiveServer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
clientMessage := clientAEAD.Seal(nil, []byte("foobar"), 42, []byte("aad"))
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// nullAEAD handles not-yet encrypted packets
|
||||
type nullAEADFNV128a struct {
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ AEAD = &nullAEADFNV128a{}
|
||||
|
||||
// Open and verify the ciphertext
|
||||
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
|
||||
if len(src) < 12 {
|
||||
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
|
||||
}
|
||||
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src[12:])
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Client"))
|
||||
} else {
|
||||
hash.Write([]byte("Server"))
|
||||
}
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
if !bytes.Equal(sum[:12], src[:12]) {
|
||||
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
|
||||
}
|
||||
return src[12:], nil
|
||||
}
|
||||
|
||||
// Seal writes hash and ciphertext to the buffer
|
||||
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
|
||||
if cap(dst) < 12+len(src) {
|
||||
dst = make([]byte, 12+len(src))
|
||||
} else {
|
||||
dst = dst[:12+len(src)]
|
||||
}
|
||||
|
||||
hash := fnv.New128a()
|
||||
hash.Write(associatedData)
|
||||
hash.Write(src)
|
||||
|
||||
if n.perspective == protocol.PerspectiveServer {
|
||||
hash.Write([]byte("Server"))
|
||||
} else {
|
||||
hash.Write([]byte("Client"))
|
||||
}
|
||||
sum := make([]byte, 0, 16)
|
||||
sum = hash.Sum(sum)
|
||||
// The tag is written in little endian, so we need to reverse the slice.
|
||||
reverse(sum)
|
||||
|
||||
copy(dst[12:], src)
|
||||
copy(dst, sum[:12])
|
||||
return dst
|
||||
}
|
||||
|
||||
func (n *nullAEADFNV128a) Overhead() int {
|
||||
return 12
|
||||
}
|
||||
|
||||
func reverse(a []byte) {
|
||||
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
|
||||
a[left], a[right] = a[right], a[left]
|
||||
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("NullAEAD using FNV128a", func() {
|
||||
aad := []byte("All human beings are born free and equal in dignity and rights.")
|
||||
plainText := []byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")
|
||||
hash36 := []byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7}
|
||||
|
||||
var aeadServer AEAD
|
||||
var aeadClient AEAD
|
||||
|
||||
BeforeEach(func() {
|
||||
aeadServer = &nullAEADFNV128a{protocol.PerspectiveServer}
|
||||
aeadClient = &nullAEADFNV128a{protocol.PerspectiveClient}
|
||||
})
|
||||
|
||||
It("seals and opens, client => server", func() {
|
||||
cipherText := aeadClient.Seal(nil, plainText, 0, aad)
|
||||
res, err := aeadServer.Open(nil, cipherText, 0, aad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
|
||||
})
|
||||
|
||||
It("seals and opens, server => client", func() {
|
||||
cipherText := aeadServer.Seal(nil, plainText, 0, aad)
|
||||
res, err := aeadClient.Open(nil, cipherText, 0, aad)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")))
|
||||
})
|
||||
|
||||
It("rejects short ciphertexts", func() {
|
||||
_, err := aeadServer.Open(nil, nil, 0, nil)
|
||||
Expect(err).To(MatchError("NullAEAD: ciphertext cannot be less than 12 bytes long"))
|
||||
})
|
||||
|
||||
It("seals in-place", func() {
|
||||
buf := make([]byte, 6, 12+6)
|
||||
copy(buf, []byte("foobar"))
|
||||
res := aeadServer.Seal(buf[0:0], buf, 0, nil)
|
||||
buf = buf[:12+6]
|
||||
Expect(buf[12:]).To(Equal([]byte("foobar")))
|
||||
Expect(res[12:]).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("fails", func() {
|
||||
cipherText := append(append(hash36, plainText...), byte(0x42))
|
||||
_, err := aeadClient.Open(nil, cipherText, 0, aad)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
|
@ -1,17 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("NullAEAD", func() {
|
||||
It("selects the right FVN variant", func() {
|
||||
connID := protocol.ConnectionID([]byte{0x42, 0, 0, 0, 0, 0, 0, 0})
|
||||
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.Version39)).To(Equal(&nullAEADFNV128a{
|
||||
perspective: protocol.PerspectiveClient,
|
||||
}))
|
||||
Expect(NewNullAEAD(protocol.PerspectiveClient, connID, protocol.VersionTLS)).To(BeAssignableToTypeOf(&aeadAESGCM{}))
|
||||
})
|
||||
})
|
|
@ -1,66 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
// signServerProof signs CHLO and server config for use in the server proof
|
||||
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
|
||||
chloHash := sha256.Sum256(chlo)
|
||||
hash.Write([]byte{32, 0, 0, 0})
|
||||
hash.Write(chloHash[:])
|
||||
hash.Write(serverConfigData)
|
||||
|
||||
key, ok := cert.PrivateKey.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
|
||||
}
|
||||
|
||||
opts := crypto.SignerOpts(crypto.SHA256)
|
||||
|
||||
if _, ok = key.(*rsa.PrivateKey); ok {
|
||||
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
|
||||
}
|
||||
|
||||
return key.Sign(rand.Reader, hash.Sum(nil), opts)
|
||||
}
|
||||
|
||||
// verifyServerProof verifies the server proof signature
|
||||
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
|
||||
chloHash := sha256.Sum256(chlo)
|
||||
hash.Write([]byte{32, 0, 0, 0})
|
||||
hash.Write(chloHash[:])
|
||||
hash.Write(serverConfigData)
|
||||
|
||||
// RSA
|
||||
if cert.PublicKeyAlgorithm == x509.RSA {
|
||||
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
|
||||
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ECDSA
|
||||
signature := &ecdsaSignature{}
|
||||
rest, err := asn1.Unmarshal(proof, signature)
|
||||
if err != nil || len(rest) != 0 {
|
||||
return false
|
||||
}
|
||||
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
|
||||
}
|
|
@ -1,127 +0,0 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"math/big"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/testdata"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Proof", func() {
|
||||
It("gives valid signatures with the key in internal/testdata", func() {
|
||||
key := &testdata.GetTLSConfig().Certificates[0]
|
||||
signature, err := signServerProof(key, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Generated with:
|
||||
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
|
||||
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
|
||||
err = rsa.VerifyPSS(key.PrivateKey.(*rsa.PrivateKey).Public().(*rsa.PublicKey), crypto.SHA256, data, signature, &rsa.PSSOptions{SaltLength: 32})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("when using RSA", func() {
|
||||
generateCert := func() (*rsa.PrivateKey, *x509.Certificate) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return key, cert
|
||||
}
|
||||
|
||||
It("verifies a signature", func() {
|
||||
key, cert := generateCert()
|
||||
chlo := []byte("chlo")
|
||||
scfg := []byte("scfg")
|
||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects invalid signatures", func() {
|
||||
key, cert := generateCert()
|
||||
chlo := []byte("client hello")
|
||||
scfg := []byte("sever config")
|
||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
|
||||
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
|
||||
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when using ECDSA", func() {
|
||||
generateCert := func() (*ecdsa.PrivateKey, *x509.Certificate) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
certTemplate := x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
return key, cert
|
||||
}
|
||||
|
||||
It("gives valid signatures", func() {
|
||||
key, _ := generateCert()
|
||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, []byte{'C', 'H', 'L', 'O'}, []byte{'S', 'C', 'F', 'G'})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Generated with:
|
||||
// ruby -e 'require "digest"; p Digest::SHA256.digest("QUIC CHLO and server config signature\x00" + "\x20\x00\x00\x00" + Digest::SHA256.digest("CHLO") + "SCFG")'
|
||||
data := []byte("W\xA6\xFC\xDE\xC7\xD2>c\xE6\xB5\xF6\tq\x9E|<~1\xA33\x01\xCA=\x19\xBD\xC1\xE4\xB0\xBA\x9B\x16%")
|
||||
s := &ecdsaSignature{}
|
||||
_, err = asn1.Unmarshal(signature, s)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
b := ecdsa.Verify(key.Public().(*ecdsa.PublicKey), data, s.R, s.S)
|
||||
Expect(b).To(BeTrue())
|
||||
})
|
||||
|
||||
It("verifies a signature", func() {
|
||||
key, cert := generateCert()
|
||||
chlo := []byte("chlo")
|
||||
scfg := []byte("server config")
|
||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(verifyServerProof(signature, cert, chlo, scfg)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects invalid signatures", func() {
|
||||
key, cert := generateCert()
|
||||
chlo := []byte("client hello")
|
||||
scfg := []byte("server config")
|
||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key}, chlo, scfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(verifyServerProof(append(signature, byte(0x99)), cert, chlo, scfg)).To(BeFalse())
|
||||
Expect(verifyServerProof(signature, cert, chlo[:len(chlo)-2], scfg)).To(BeFalse())
|
||||
Expect(verifyServerProof(signature, cert, chlo, scfg[:len(scfg)-2])).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects signatures generated with a different certificate", func() {
|
||||
key1, cert1 := generateCert()
|
||||
key2, cert2 := generateCert()
|
||||
Expect(key1.PublicKey).ToNot(Equal(key2))
|
||||
Expect(cert1.Equal(cert2)).To(BeFalse())
|
||||
chlo := []byte("chlo")
|
||||
scfg := []byte("sfcg")
|
||||
signature, err := signServerProof(&tls.Certificate{PrivateKey: key1}, chlo, scfg)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(verifyServerProof(signature, cert2, chlo, scfg)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -16,8 +16,7 @@ type streamFlowController struct {
|
|||
|
||||
queueWindowUpdate func()
|
||||
|
||||
connection connectionFlowControllerI
|
||||
contributesToConnection bool // does the stream contribute to connection level flow control
|
||||
connection connectionFlowControllerI
|
||||
|
||||
receivedFinalOffset bool
|
||||
}
|
||||
|
@ -27,7 +26,6 @@ var _ StreamFlowController = &streamFlowController{}
|
|||
// NewStreamFlowController gets a new flow controller for a stream
|
||||
func NewStreamFlowController(
|
||||
streamID protocol.StreamID,
|
||||
contributesToConnection bool,
|
||||
cfc ConnectionFlowController,
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
|
@ -37,10 +35,9 @@ func NewStreamFlowController(
|
|||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
contributesToConnection: contributesToConnection,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
streamID: streamID,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
|
@ -87,32 +84,21 @@ func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCou
|
|||
if c.checkFlowControlViolation() {
|
||||
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
|
||||
}
|
||||
if c.contributesToConnection {
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
return nil
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesRead(n)
|
||||
if c.contributesToConnection {
|
||||
c.connection.AddBytesRead(n)
|
||||
}
|
||||
c.connection.AddBytesRead(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesSent(n)
|
||||
if c.contributesToConnection {
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
window := c.baseFlowController.sendWindowSize()
|
||||
if c.contributesToConnection {
|
||||
window = utils.MinByteCount(window, c.connection.SendWindowSize())
|
||||
}
|
||||
return window
|
||||
return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
|
||||
}
|
||||
|
||||
func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
||||
|
@ -122,9 +108,7 @@ func (c *streamFlowController) MaybeQueueWindowUpdate() {
|
|||
if hasWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
if c.contributesToConnection {
|
||||
c.connection.MaybeQueueWindowUpdate()
|
||||
}
|
||||
c.connection.MaybeQueueWindowUpdate()
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
|
@ -140,9 +124,7 @@ func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
|||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||
if c.contributesToConnection {
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
|
|
|
@ -40,12 +40,11 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
|
||||
It("sets the send and receive windows", func() {
|
||||
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
|
||||
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, nil, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
Expect(fc.streamID).To(Equal(protocol.StreamID(5)))
|
||||
Expect(fc.receiveWindow).To(Equal(receiveWindow))
|
||||
Expect(fc.maxReceiveWindowSize).To(Equal(maxReceiveWindow))
|
||||
Expect(fc.sendWindow).To(Equal(sendWindow))
|
||||
Expect(fc.contributesToConnection).To(BeTrue())
|
||||
})
|
||||
|
||||
It("queues window updates with the correction stream ID", func() {
|
||||
|
@ -56,7 +55,7 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
}
|
||||
|
||||
cc := NewConnectionFlowController(0, 0, nil, nil, utils.DefaultLogger)
|
||||
fc := NewStreamFlowController(5, true, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
fc := NewStreamFlowController(5, cc, receiveWindow, maxReceiveWindow, sendWindow, queueWindowUpdate, rttStats, utils.DefaultLogger).(*streamFlowController)
|
||||
fc.AddBytesRead(receiveWindow)
|
||||
fc.MaybeQueueWindowUpdate()
|
||||
Expect(queued).To(BeTrue())
|
||||
|
@ -82,21 +81,12 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
|
||||
It("informs the connection flow controller about received data", func() {
|
||||
controller.highestReceived = 10
|
||||
controller.contributesToConnection = true
|
||||
controller.connection.(*connectionFlowController).highestReceived = 100
|
||||
err := controller.UpdateHighestReceived(20, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100 + 10)))
|
||||
})
|
||||
|
||||
It("doesn't informs the connection flow controller about received data if it doesn't contribute", func() {
|
||||
controller.highestReceived = 10
|
||||
controller.connection.(*connectionFlowController).highestReceived = 100
|
||||
err := controller.UpdateHighestReceived(20, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(controller.connection.(*connectionFlowController).highestReceived).To(Equal(protocol.ByteCount(100)))
|
||||
})
|
||||
|
||||
It("does not decrease the highestReceived", func() {
|
||||
controller.highestReceived = 1337
|
||||
err := controller.UpdateHighestReceived(1000, false)
|
||||
|
@ -111,6 +101,7 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
})
|
||||
|
||||
It("does not give a flow control violation when using the window completely", func() {
|
||||
controller.connection.(*connectionFlowController).receiveWindow = receiveWindow
|
||||
err := controller.UpdateHighestReceived(receiveWindow, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
@ -163,19 +154,10 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("registering data read", func() {
|
||||
It("saves when data is read, on a stream not contributing to the connection", func() {
|
||||
controller.AddBytesRead(100)
|
||||
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(100)))
|
||||
Expect(controller.connection.(*connectionFlowController).bytesRead).To(BeZero())
|
||||
})
|
||||
|
||||
It("saves when data is read, on a stream not contributing to the connection", func() {
|
||||
controller.contributesToConnection = true
|
||||
controller.AddBytesRead(200)
|
||||
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200)))
|
||||
Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200)))
|
||||
})
|
||||
It("saves when data is read", func() {
|
||||
controller.AddBytesRead(200)
|
||||
Expect(controller.bytesRead).To(Equal(protocol.ByteCount(200)))
|
||||
Expect(controller.connection.(*connectionFlowController).bytesRead).To(Equal(protocol.ByteCount(200)))
|
||||
})
|
||||
|
||||
Context("generating window updates", func() {
|
||||
|
@ -209,7 +191,6 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
})
|
||||
|
||||
It("queues connection-level window updates", func() {
|
||||
controller.contributesToConnection = true
|
||||
controller.MaybeQueueWindowUpdate()
|
||||
Expect(queuedConnWindowUpdate).To(BeFalse())
|
||||
controller.AddBytesRead(60)
|
||||
|
@ -219,7 +200,6 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
|
||||
It("tells the connection flow controller when the window was autotuned", func() {
|
||||
oldOffset := controller.bytesRead
|
||||
controller.contributesToConnection = true
|
||||
setRtt(scaleDuration(20 * time.Millisecond))
|
||||
controller.epochStartOffset = oldOffset
|
||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
||||
|
@ -230,19 +210,6 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(float64(controller.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)))
|
||||
})
|
||||
|
||||
It("doesn't tell the connection flow controller if it doesn't contribute", func() {
|
||||
oldOffset := controller.bytesRead
|
||||
controller.contributesToConnection = false
|
||||
setRtt(scaleDuration(20 * time.Millisecond))
|
||||
controller.epochStartOffset = oldOffset
|
||||
controller.epochStartTime = time.Now().Add(-time.Millisecond)
|
||||
controller.AddBytesRead(55)
|
||||
offset := controller.GetWindowUpdate()
|
||||
Expect(offset).ToNot(BeZero())
|
||||
Expect(controller.receiveWindowSize).To(Equal(2 * oldWindowSize))
|
||||
Expect(controller.connection.(*connectionFlowController).receiveWindowSize).To(Equal(protocol.ByteCount(2 * oldWindowSize))) // unchanged
|
||||
})
|
||||
|
||||
It("doesn't increase the window after a final offset was already received", func() {
|
||||
controller.AddBytesRead(30)
|
||||
err := controller.UpdateHighestReceived(90, true)
|
||||
|
@ -257,20 +224,13 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
|
||||
Context("sending data", func() {
|
||||
It("gets the size of the send window", func() {
|
||||
controller.connection.UpdateSendWindow(1000)
|
||||
controller.UpdateSendWindow(15)
|
||||
controller.AddBytesSent(5)
|
||||
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
|
||||
})
|
||||
|
||||
It("doesn't care about the connection-level window, if it doesn't contribute", func() {
|
||||
controller.UpdateSendWindow(15)
|
||||
controller.connection.UpdateSendWindow(1)
|
||||
controller.AddBytesSent(5)
|
||||
Expect(controller.SendWindowSize()).To(Equal(protocol.ByteCount(10)))
|
||||
})
|
||||
|
||||
It("makes sure that it doesn't overflow the connection-level window", func() {
|
||||
controller.contributesToConnection = true
|
||||
controller.connection.UpdateSendWindow(12)
|
||||
controller.UpdateSendWindow(20)
|
||||
controller.AddBytesSent(10)
|
||||
|
@ -278,7 +238,6 @@ var _ = Describe("Stream Flow controller", func() {
|
|||
})
|
||||
|
||||
It("doesn't say that it's blocked, if only the connection is blocked", func() {
|
||||
controller.contributesToConnection = true
|
||||
controller.connection.UpdateSendWindow(50)
|
||||
controller.UpdateSendWindow(100)
|
||||
controller.AddBytesSent(50)
|
||||
|
|
|
@ -46,7 +46,7 @@ func (m messageType) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
type cryptoSetupTLS struct {
|
||||
type cryptoSetup struct {
|
||||
tlsConf *qtls.Config
|
||||
|
||||
messageChan chan []byte
|
||||
|
@ -92,8 +92,8 @@ type cryptoSetupTLS struct {
|
|||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ qtls.RecordLayer = &cryptoSetupTLS{}
|
||||
var _ CryptoSetupTLS = &cryptoSetupTLS{}
|
||||
var _ qtls.RecordLayer = &cryptoSetup{}
|
||||
var _ CryptoSetup = &cryptoSetup{}
|
||||
|
||||
type versionInfo struct {
|
||||
initialVersion protocol.VersionNumber
|
||||
|
@ -101,8 +101,8 @@ type versionInfo struct {
|
|||
currentVersion protocol.VersionNumber
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSClient creates a new TLS crypto setup for the client
|
||||
func NewCryptoSetupTLSClient(
|
||||
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||
func NewCryptoSetupClient(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
|
@ -114,8 +114,8 @@ func NewCryptoSetupTLSClient(
|
|||
currentVersion protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (CryptoSetupTLS, <-chan struct{} /* ClientHello written */, error) {
|
||||
return newCryptoSetupTLS(
|
||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||
return newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
connID,
|
||||
|
@ -132,8 +132,8 @@ func NewCryptoSetupTLSClient(
|
|||
)
|
||||
}
|
||||
|
||||
// NewCryptoSetupTLSServer creates a new TLS crypto setup for the server
|
||||
func NewCryptoSetupTLSServer(
|
||||
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||
func NewCryptoSetupServer(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
|
@ -144,8 +144,8 @@ func NewCryptoSetupTLSServer(
|
|||
currentVersion protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (CryptoSetupTLS, error) {
|
||||
cs, _, err := newCryptoSetupTLS(
|
||||
) (CryptoSetup, error) {
|
||||
cs, _, err := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
connID,
|
||||
|
@ -162,7 +162,7 @@ func NewCryptoSetupTLSServer(
|
|||
return cs, err
|
||||
}
|
||||
|
||||
func newCryptoSetupTLS(
|
||||
func newCryptoSetup(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
|
@ -172,12 +172,12 @@ func newCryptoSetupTLS(
|
|||
versionInfo versionInfo,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
) (CryptoSetupTLS, <-chan struct{} /* ClientHello written */, error) {
|
||||
initialAEAD, err := crypto.NewNullAEAD(perspective, connID, protocol.VersionTLS)
|
||||
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
|
||||
initialAEAD, err := crypto.NewNullAEAD(connID, perspective)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cs := &cryptoSetupTLS{
|
||||
cs := &cryptoSetup{
|
||||
initialStream: initialStream,
|
||||
initialAEAD: initialAEAD,
|
||||
handshakeStream: handshakeStream,
|
||||
|
@ -221,7 +221,7 @@ func newCryptoSetupTLS(
|
|||
return cs, cs.clientHelloWrittenChan, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) RunHandshake() error {
|
||||
func (h *cryptoSetup) RunHandshake() error {
|
||||
var conn *qtls.Conn
|
||||
switch h.perspective {
|
||||
case protocol.PerspectiveClient:
|
||||
|
@ -264,7 +264,7 @@ func (h *cryptoSetupTLS) RunHandshake() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) Close() error {
|
||||
func (h *cryptoSetup) Close() error {
|
||||
close(h.closeChan)
|
||||
// wait until qtls.Handshake() actually returned
|
||||
<-h.handshakeDone
|
||||
|
@ -274,7 +274,7 @@ func (h *cryptoSetupTLS) Close() error {
|
|||
// handleMessage handles a TLS handshake message.
|
||||
// It is called by the crypto streams when a new message is available.
|
||||
// It returns if it is done with messages on the same encryption level.
|
||||
func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
|
||||
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
|
||||
msgType := messageType(data[0])
|
||||
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
|
||||
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
|
||||
|
@ -292,7 +292,7 @@ func (h *cryptoSetupTLS) HandleMessage(data []byte, encLevel protocol.Encryption
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||
var expected protocol.EncryptionLevel
|
||||
switch msgType {
|
||||
case typeClientHello,
|
||||
|
@ -313,7 +313,7 @@ func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel prot
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
|
||||
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
|
||||
switch msgType {
|
||||
case typeClientHello:
|
||||
select {
|
||||
|
@ -358,7 +358,7 @@ func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
|
||||
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
|
||||
switch msgType {
|
||||
case typeServerHello:
|
||||
// get the handshake read key
|
||||
|
@ -408,7 +408,7 @@ func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) bool {
|
|||
|
||||
// ReadHandshakeMessage is called by TLS.
|
||||
// It blocks until a new handshake message is available.
|
||||
func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
|
||||
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||
// TODO: add some error handling here (when the session is closed)
|
||||
msg, ok := <-h.messageChan
|
||||
if !ok {
|
||||
|
@ -417,7 +417,7 @@ func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) {
|
|||
return msg, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
|
||||
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
|
||||
opener := newOpener(suite.AEAD(key, iv), iv)
|
||||
|
@ -437,7 +437,7 @@ func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byt
|
|||
h.receivedReadKey <- struct{}{}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
|
||||
key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen())
|
||||
iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen())
|
||||
sealer := newSealer(suite.AEAD(key, iv), iv)
|
||||
|
@ -458,7 +458,7 @@ func (h *cryptoSetupTLS) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []by
|
|||
}
|
||||
|
||||
// WriteRecord is called when TLS writes data
|
||||
func (h *cryptoSetupTLS) WriteRecord(p []byte) (int, error) {
|
||||
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
||||
switch h.writeEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
// assume that the first WriteRecord call contains the ClientHello
|
||||
|
@ -475,7 +475,7 @@ func (h *cryptoSetupTLS) WriteRecord(p []byte) (int, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
if h.sealer != nil {
|
||||
return protocol.Encryption1RTT, h.sealer
|
||||
}
|
||||
|
@ -485,7 +485,7 @@ func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
|||
return protocol.EncryptionInitial, h.initialAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
|
||||
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
|
||||
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
|
||||
|
||||
switch level {
|
||||
|
@ -506,25 +506,25 @@ func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(level protocol.EncryptionL
|
|||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
func (h *cryptoSetup) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
return h.initialAEAD.Open(dst, src, pn, ad)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
func (h *cryptoSetup) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
if h.handshakeOpener == nil {
|
||||
return nil, errors.New("no handshake opener")
|
||||
}
|
||||
return h.handshakeOpener.Open(dst, src, pn, ad)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
func (h *cryptoSetup) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
if h.opener == nil {
|
||||
return nil, errors.New("no 1-RTT opener")
|
||||
}
|
||||
return h.opener.Open(dst, src, pn, ad)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
// TODO: return the connection state
|
||||
return ConnectionState{}
|
||||
}
|
|
@ -1,545 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type cryptoSetupClient struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
hostname string
|
||||
connID protocol.ConnectionID
|
||||
version protocol.VersionNumber
|
||||
initialVersion protocol.VersionNumber
|
||||
negotiatedVersions []protocol.VersionNumber
|
||||
|
||||
cryptoStream io.ReadWriter
|
||||
|
||||
serverConfig *serverConfigClient
|
||||
|
||||
stk []byte
|
||||
sno []byte
|
||||
nonc []byte
|
||||
proof []byte
|
||||
chloForSignature []byte
|
||||
lastSentCHLO []byte
|
||||
certManager crypto.CertManager
|
||||
|
||||
divNonceChan chan struct{}
|
||||
diversificationNonce []byte
|
||||
|
||||
clientHelloCounter int
|
||||
serverVerified bool // has the certificate chain and the proof already been verified
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
|
||||
receivedSecurePacket bool
|
||||
nullAEAD crypto.AEAD
|
||||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
|
||||
paramsChan chan<- TransportParameters
|
||||
handshakeEvent chan<- struct{}
|
||||
handshakeComplete chan<- struct{}
|
||||
|
||||
params *TransportParameters
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupClient{}
|
||||
|
||||
var (
|
||||
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available")
|
||||
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated")
|
||||
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces")
|
||||
)
|
||||
|
||||
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
|
||||
func NewCryptoSetupClient(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
version protocol.VersionNumber,
|
||||
tlsConf *tls.Config,
|
||||
params *TransportParameters,
|
||||
paramsChan chan<- TransportParameters,
|
||||
handshakeEvent chan<- struct{},
|
||||
handshakeComplete chan<- struct{},
|
||||
initialVersion protocol.VersionNumber,
|
||||
negotiatedVersions []protocol.VersionNumber,
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
divNonceChan := make(chan struct{})
|
||||
cs := &cryptoSetupClient{
|
||||
cryptoStream: cryptoStream,
|
||||
hostname: tlsConf.ServerName,
|
||||
connID: connID,
|
||||
version: version,
|
||||
certManager: crypto.NewCertManager(tlsConf),
|
||||
params: params,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
nullAEAD: nullAEAD,
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
handshakeComplete: handshakeComplete,
|
||||
initialVersion: initialVersion,
|
||||
// The server might have sent greased versions in the Version Negotiation packet.
|
||||
// We need strip those from the list, since they won't be included in the handshake tag.
|
||||
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
|
||||
divNonceChan: divNonceChan,
|
||||
logger: logger,
|
||||
}
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) RunHandshake() error {
|
||||
messageChan := make(chan HandshakeMessage)
|
||||
errorChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
message, err := ParseHandshakeMessage(h.cryptoStream)
|
||||
if err != nil {
|
||||
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
|
||||
return
|
||||
}
|
||||
messageChan <- message
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if err := h.maybeUpgradeCrypto(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.mutex.RLock()
|
||||
sendCHLO := h.secureAEAD == nil
|
||||
h.mutex.RUnlock()
|
||||
if sendCHLO {
|
||||
if err := h.sendCHLO(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var message HandshakeMessage
|
||||
select {
|
||||
case <-h.divNonceChan:
|
||||
// there's no message to process, but we should try upgrading the crypto again
|
||||
continue
|
||||
case message = <-messageChan:
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Debugf("Got %s", message)
|
||||
switch message.Tag {
|
||||
case TagREJ:
|
||||
if err := h.handleREJMessage(message.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
case TagSHLO:
|
||||
params, err := h.handleSHLOMessage(message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// blocks until the session has received the parameters
|
||||
h.paramsChan <- *params
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.handshakeComplete)
|
||||
default:
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
|
||||
var err error
|
||||
|
||||
if stk, ok := cryptoData[TagSTK]; ok {
|
||||
h.stk = stk
|
||||
}
|
||||
|
||||
if sno, ok := cryptoData[TagSNO]; ok {
|
||||
h.sno = sno
|
||||
}
|
||||
|
||||
// TODO: what happens if the server sends a different server config in two packets?
|
||||
if scfg, ok := cryptoData[TagSCFG]; ok {
|
||||
h.serverConfig, err = parseServerConfig(scfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.serverConfig.IsExpired() {
|
||||
return qerr.CryptoServerConfigExpired
|
||||
}
|
||||
|
||||
// now that we have a server config, we can use its OBIT value to generate a client nonce
|
||||
if len(h.nonc) == 0 {
|
||||
err = h.generateClientNonce()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if proof, ok := cryptoData[TagPROF]; ok {
|
||||
h.proof = proof
|
||||
h.chloForSignature = h.lastSentCHLO
|
||||
}
|
||||
|
||||
if crt, ok := cryptoData[TagCERT]; ok {
|
||||
err := h.certManager.SetData(crt)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
|
||||
}
|
||||
|
||||
err = h.certManager.Verify(h.hostname)
|
||||
if err != nil {
|
||||
h.logger.Infof("Certificate validation failed: %s", err.Error())
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
}
|
||||
|
||||
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
|
||||
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
|
||||
if !validProof {
|
||||
h.logger.Infof("Server proof verification failed")
|
||||
return qerr.ProofInvalid
|
||||
}
|
||||
|
||||
h.serverVerified = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if !h.receivedSecurePacket {
|
||||
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
|
||||
}
|
||||
|
||||
if sno, ok := cryptoData[TagSNO]; ok {
|
||||
h.sno = sno
|
||||
}
|
||||
|
||||
serverPubs, ok := cryptoData[TagPUBS]
|
||||
if !ok {
|
||||
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||
}
|
||||
|
||||
verTag, ok := cryptoData[TagVER]
|
||||
if !ok {
|
||||
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
|
||||
}
|
||||
if !h.validateVersionList(verTag) {
|
||||
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||
}
|
||||
|
||||
nonce := append(h.nonc, h.sno...)
|
||||
|
||||
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leafCert := h.certManager.GetLeafCert()
|
||||
|
||||
h.forwardSecureAEAD, err = h.keyDerivation(
|
||||
true,
|
||||
ephermalSharedSecret,
|
||||
nonce,
|
||||
h.connID,
|
||||
h.lastSentCHLO,
|
||||
h.serverConfig.Get(),
|
||||
leafCert,
|
||||
nil,
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
|
||||
|
||||
params, err := readHelloMap(cryptoData)
|
||||
if err != nil {
|
||||
return nil, qerr.InvalidCryptoMessageParameter
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
|
||||
numNegotiatedVersions := len(h.negotiatedVersions)
|
||||
if numNegotiatedVersions == 0 {
|
||||
return true
|
||||
}
|
||||
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
|
||||
return false
|
||||
}
|
||||
|
||||
b := bytes.NewReader(verTags)
|
||||
for i := 0; i < numNegotiatedVersions; i++ {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil { // should never occur, since the length was already checked
|
||||
return false
|
||||
}
|
||||
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.forwardSecureAEAD != nil {
|
||||
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
return data, protocol.EncryptionForwardSecure, nil
|
||||
}
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
|
||||
if h.secureAEAD != nil {
|
||||
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return data, protocol.EncryptionSecure, nil
|
||||
}
|
||||
if h.receivedSecurePacket {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
}
|
||||
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return res, protocol.EncryptionUnencrypted, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
if h.forwardSecureAEAD != nil {
|
||||
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
||||
} else if h.secureAEAD != nil {
|
||||
return protocol.EncryptionSecure, h.secureAEAD
|
||||
} else {
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionUnencrypted:
|
||||
return h.nullAEAD, nil
|
||||
case protocol.EncryptionSecure:
|
||||
if h.secureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupClient: no secureAEAD")
|
||||
}
|
||||
return h.secureAEAD, nil
|
||||
case protocol.EncryptionForwardSecure:
|
||||
if h.forwardSecureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
|
||||
}
|
||||
return h.forwardSecureAEAD, nil
|
||||
}
|
||||
return nil, errors.New("CryptoSetupClient: no encryption level specified")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
HandshakeComplete: h.forwardSecureAEAD != nil,
|
||||
PeerCertificates: h.certManager.GetChain(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
|
||||
h.mutex.Lock()
|
||||
if len(h.diversificationNonce) > 0 {
|
||||
defer h.mutex.Unlock()
|
||||
if !bytes.Equal(h.diversificationNonce, divNonce) {
|
||||
return errConflictingDiversificationNonces
|
||||
}
|
||||
return nil
|
||||
}
|
||||
h.diversificationNonce = divNonce
|
||||
h.mutex.Unlock()
|
||||
h.divNonceChan <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) sendCHLO() error {
|
||||
h.clientHelloCounter++
|
||||
if h.clientHelloCounter > protocol.MaxClientHellos {
|
||||
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos))
|
||||
}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
tags, err := h.getTags()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.addPadding(tags)
|
||||
message := HandshakeMessage{
|
||||
Tag: TagCHLO,
|
||||
Data: tags,
|
||||
}
|
||||
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
message.Write(b)
|
||||
|
||||
_, err = h.cryptoStream.Write(b.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.lastSentCHLO = b.Bytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
|
||||
tags := h.params.getHelloMap()
|
||||
tags[TagSNI] = []byte(h.hostname)
|
||||
tags[TagPDMD] = []byte("X509")
|
||||
|
||||
ccs := h.certManager.GetCommonCertificateHashes()
|
||||
if len(ccs) > 0 {
|
||||
tags[TagCCS] = ccs
|
||||
}
|
||||
|
||||
versionTag := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
|
||||
tags[TagVER] = versionTag
|
||||
|
||||
if len(h.stk) > 0 {
|
||||
tags[TagSTK] = h.stk
|
||||
}
|
||||
if len(h.sno) > 0 {
|
||||
tags[TagSNO] = h.sno
|
||||
}
|
||||
|
||||
if h.serverConfig != nil {
|
||||
tags[TagSCID] = h.serverConfig.ID
|
||||
|
||||
leafCert := h.certManager.GetLeafCert()
|
||||
if leafCert != nil {
|
||||
certHash, _ := h.certManager.GetLeafCertHash()
|
||||
xlct := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(xlct, certHash)
|
||||
|
||||
tags[TagNONC] = h.nonc
|
||||
tags[TagXLCT] = xlct
|
||||
tags[TagKEXS] = []byte("C255")
|
||||
tags[TagAEAD] = []byte("AESG")
|
||||
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
|
||||
}
|
||||
}
|
||||
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
|
||||
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
|
||||
var size int
|
||||
for _, tag := range tags {
|
||||
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
|
||||
}
|
||||
paddingSize := protocol.MinClientHelloSize - size
|
||||
if paddingSize > 0 {
|
||||
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
|
||||
if !h.serverVerified {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
leafCert := h.certManager.GetLeafCert()
|
||||
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
|
||||
var err error
|
||||
var nonce []byte
|
||||
if h.sno == nil {
|
||||
nonce = h.nonc
|
||||
} else {
|
||||
nonce = append(h.nonc, h.sno...)
|
||||
}
|
||||
|
||||
h.secureAEAD, err = h.keyDerivation(
|
||||
false,
|
||||
h.serverConfig.sharedSecret,
|
||||
nonce,
|
||||
h.connID,
|
||||
h.lastSentCHLO,
|
||||
h.serverConfig.Get(),
|
||||
leafCert,
|
||||
h.diversificationNonce,
|
||||
protocol.PerspectiveClient,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupClient) generateClientNonce() error {
|
||||
if len(h.nonc) > 0 {
|
||||
return errClientNonceAlreadyExists
|
||||
}
|
||||
|
||||
nonc := make([]byte, 32)
|
||||
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix()))
|
||||
|
||||
if len(h.serverConfig.obit) != 8 {
|
||||
return errNoObitForClientNonce
|
||||
}
|
||||
|
||||
copy(nonc[4:12], h.serverConfig.obit)
|
||||
|
||||
_, err := rand.Read(nonc[12:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.nonc = nonc
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,470 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// QuicCryptoKeyDerivationFunction is used for key derivation
|
||||
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
|
||||
|
||||
// KeyExchangeFunction is used to make a new KEX
|
||||
type KeyExchangeFunction func() (crypto.KeyExchange, error)
|
||||
|
||||
// The CryptoSetupServer handles all things crypto for the Session
|
||||
type cryptoSetupServer struct {
|
||||
mutex sync.RWMutex
|
||||
|
||||
connID protocol.ConnectionID
|
||||
remoteAddr net.Addr
|
||||
scfg *ServerConfig
|
||||
diversificationNonce []byte
|
||||
|
||||
version protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
|
||||
acceptSTKCallback func(net.Addr, *Cookie) bool
|
||||
|
||||
nullAEAD crypto.AEAD
|
||||
secureAEAD crypto.AEAD
|
||||
forwardSecureAEAD crypto.AEAD
|
||||
receivedForwardSecurePacket bool
|
||||
receivedSecurePacket bool
|
||||
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
|
||||
|
||||
receivedParams bool
|
||||
paramsChan chan<- TransportParameters
|
||||
handshakeEvent chan<- struct{}
|
||||
handshakeComplete chan<- struct{}
|
||||
|
||||
keyDerivation QuicCryptoKeyDerivationFunction
|
||||
keyExchange KeyExchangeFunction
|
||||
|
||||
cryptoStream io.ReadWriter
|
||||
|
||||
params *TransportParameters
|
||||
|
||||
sni string // need to fill out the ConnectionState
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetupServer{}
|
||||
|
||||
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
|
||||
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
|
||||
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
|
||||
|
||||
// NewCryptoSetup creates a new CryptoSetup instance for a server
|
||||
func NewCryptoSetup(
|
||||
cryptoStream io.ReadWriter,
|
||||
connID protocol.ConnectionID,
|
||||
remoteAddr net.Addr,
|
||||
version protocol.VersionNumber,
|
||||
divNonce []byte,
|
||||
scfg *ServerConfig,
|
||||
params *TransportParameters,
|
||||
supportedVersions []protocol.VersionNumber,
|
||||
acceptSTK func(net.Addr, *Cookie) bool,
|
||||
paramsChan chan<- TransportParameters,
|
||||
handshakeEvent chan<- struct{},
|
||||
handshakeComplete chan<- struct{},
|
||||
logger utils.Logger,
|
||||
) (CryptoSetup, error) {
|
||||
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cryptoSetupServer{
|
||||
cryptoStream: cryptoStream,
|
||||
connID: connID,
|
||||
remoteAddr: remoteAddr,
|
||||
version: version,
|
||||
supportedVersions: supportedVersions,
|
||||
diversificationNonce: divNonce,
|
||||
scfg: scfg,
|
||||
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
|
||||
keyExchange: getEphermalKEX,
|
||||
nullAEAD: nullAEAD,
|
||||
params: params,
|
||||
acceptSTKCallback: acceptSTK,
|
||||
sentSHLO: make(chan struct{}),
|
||||
paramsChan: paramsChan,
|
||||
handshakeEvent: handshakeEvent,
|
||||
handshakeComplete: handshakeComplete,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleCryptoStream reads and writes messages on the crypto stream
|
||||
func (h *cryptoSetupServer) RunHandshake() error {
|
||||
for {
|
||||
var chloData bytes.Buffer
|
||||
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
|
||||
if err != nil {
|
||||
return qerr.HandshakeFailed
|
||||
}
|
||||
if message.Tag != TagCHLO {
|
||||
return qerr.InvalidCryptoMessageType
|
||||
}
|
||||
|
||||
h.logger.Debugf("Got %s", message)
|
||||
done, err := h.handleMessage(chloData.Bytes(), message.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
|
||||
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
|
||||
return false, ErrNSTPExperiment
|
||||
}
|
||||
|
||||
sniSlice, ok := cryptoData[TagSNI]
|
||||
if !ok {
|
||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
||||
}
|
||||
sni := string(sniSlice)
|
||||
if sni == "" {
|
||||
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
|
||||
}
|
||||
h.sni = sni
|
||||
|
||||
// prevent version downgrade attacks
|
||||
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
|
||||
verSlice, ok := cryptoData[TagVER]
|
||||
if !ok {
|
||||
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")
|
||||
}
|
||||
if len(verSlice) != 4 {
|
||||
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")
|
||||
}
|
||||
ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice))
|
||||
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
|
||||
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
|
||||
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
var err error
|
||||
|
||||
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
params, err := readHelloMap(cryptoData)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// blocks until the session has received the parameters
|
||||
if !h.receivedParams {
|
||||
h.receivedParams = true
|
||||
h.paramsChan <- *params
|
||||
}
|
||||
|
||||
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
|
||||
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
|
||||
reply, err = h.handleCHLO(sni, chloData, cryptoData)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if _, err := h.cryptoStream.Write(reply); err != nil {
|
||||
return false, err
|
||||
}
|
||||
h.handshakeEvent <- struct{}{}
|
||||
close(h.sentSHLO)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// We have an inchoate or non-matching CHLO, we now send a rejection
|
||||
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, err = h.cryptoStream.Write(reply)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Open a message
|
||||
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
if h.forwardSecureAEAD != nil {
|
||||
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
|
||||
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
|
||||
h.receivedForwardSecurePacket = true
|
||||
// wait for the send on the handshakeEvent chan
|
||||
<-h.sentSHLO
|
||||
close(h.handshakeComplete)
|
||||
}
|
||||
return res, protocol.EncryptionForwardSecure, nil
|
||||
}
|
||||
if h.receivedForwardSecurePacket {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
}
|
||||
if h.secureAEAD != nil {
|
||||
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err == nil {
|
||||
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
|
||||
h.receivedSecurePacket = true
|
||||
return res, protocol.EncryptionSecure, nil
|
||||
}
|
||||
if h.receivedSecurePacket {
|
||||
return nil, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
}
|
||||
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
|
||||
if err != nil {
|
||||
return res, protocol.EncryptionUnspecified, err
|
||||
}
|
||||
return res, protocol.EncryptionUnencrypted, err
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
if h.forwardSecureAEAD != nil {
|
||||
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
|
||||
}
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
if h.secureAEAD != nil {
|
||||
return protocol.EncryptionSecure, h.secureAEAD
|
||||
}
|
||||
return protocol.EncryptionUnencrypted, h.nullAEAD
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
switch encLevel {
|
||||
case protocol.EncryptionUnencrypted:
|
||||
return h.nullAEAD, nil
|
||||
case protocol.EncryptionSecure:
|
||||
if h.secureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupServer: no secureAEAD")
|
||||
}
|
||||
return h.secureAEAD, nil
|
||||
case protocol.EncryptionForwardSecure:
|
||||
if h.forwardSecureAEAD == nil {
|
||||
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
|
||||
}
|
||||
return h.forwardSecureAEAD, nil
|
||||
}
|
||||
return nil, errors.New("CryptoSetupServer: no encryption level specified")
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
|
||||
if _, ok := cryptoData[TagPUBS]; !ok {
|
||||
return true
|
||||
}
|
||||
scid, ok := cryptoData[TagSCID]
|
||||
if !ok || !bytes.Equal(h.scfg.ID, scid) {
|
||||
return true
|
||||
}
|
||||
xlctTag, ok := cryptoData[TagXLCT]
|
||||
if !ok || len(xlctTag) != 8 {
|
||||
return true
|
||||
}
|
||||
xlct := binary.LittleEndian.Uint64(xlctTag)
|
||||
if crypto.HashCert(cert) != xlct {
|
||||
return true
|
||||
}
|
||||
return !h.acceptSTK(cryptoData[TagSTK])
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
|
||||
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
|
||||
if err != nil {
|
||||
h.logger.Debugf("STK invalid: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
return h.acceptSTKCallback(h.remoteAddr, stk)
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
replyMap := map[Tag][]byte{
|
||||
TagSCFG: h.scfg.Get(),
|
||||
TagSTK: token,
|
||||
TagSVID: []byte("quic-go"),
|
||||
}
|
||||
|
||||
if h.acceptSTK(cryptoData[TagSTK]) {
|
||||
proof, err := h.scfg.Sign(sni, chlo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commonSetHashes := cryptoData[TagCCS]
|
||||
cachedCertsHashes := cryptoData[TagCCRT]
|
||||
|
||||
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Token was valid, send more details
|
||||
replyMap[TagPROF] = proof
|
||||
replyMap[TagCERT] = certCompressed
|
||||
}
|
||||
|
||||
message := HandshakeMessage{
|
||||
Tag: TagREJ,
|
||||
Data: replyMap,
|
||||
}
|
||||
|
||||
var serverReply bytes.Buffer
|
||||
message.Write(&serverReply)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return serverReply.Bytes(), nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
|
||||
// We have a CHLO matching our server config, we can continue with the 0-RTT handshake
|
||||
sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverNonce := make([]byte, 32)
|
||||
if _, err = rand.Read(serverNonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientNonce := cryptoData[TagNONC]
|
||||
err = h.validateClientNonce(clientNonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aead := cryptoData[TagAEAD]
|
||||
if !bytes.Equal(aead, []byte("AESG")) {
|
||||
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
|
||||
}
|
||||
|
||||
kexs := cryptoData[TagKEXS]
|
||||
if !bytes.Equal(kexs, []byte("C255")) {
|
||||
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
|
||||
}
|
||||
|
||||
h.secureAEAD, err = h.keyDerivation(
|
||||
false,
|
||||
sharedSecret,
|
||||
clientNonce,
|
||||
h.connID,
|
||||
data,
|
||||
h.scfg.Get(),
|
||||
certUncompressed,
|
||||
h.diversificationNonce,
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for secure encryption.")
|
||||
h.handshakeEvent <- struct{}{}
|
||||
|
||||
// Generate a new curve instance to derive the forward secure key
|
||||
var fsNonce bytes.Buffer
|
||||
fsNonce.Write(clientNonce)
|
||||
fsNonce.Write(serverNonce)
|
||||
ephermalKex, err := h.keyExchange()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.forwardSecureAEAD, err = h.keyDerivation(
|
||||
true,
|
||||
ephermalSharedSecret,
|
||||
fsNonce.Bytes(),
|
||||
h.connID,
|
||||
data,
|
||||
h.scfg.Get(),
|
||||
certUncompressed,
|
||||
nil,
|
||||
protocol.PerspectiveServer,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
|
||||
|
||||
replyMap := h.params.getHelloMap()
|
||||
// add crypto parameters
|
||||
verTag := &bytes.Buffer{}
|
||||
for _, v := range h.supportedVersions {
|
||||
utils.BigEndian.WriteUint32(verTag, uint32(v))
|
||||
}
|
||||
replyMap[TagPUBS] = ephermalKex.PublicKey()
|
||||
replyMap[TagSNO] = serverNonce
|
||||
replyMap[TagVER] = verTag.Bytes()
|
||||
|
||||
// note that the SHLO *has* to fit into one packet
|
||||
message := HandshakeMessage{
|
||||
Tag: TagSHLO,
|
||||
Data: replyMap,
|
||||
}
|
||||
var reply bytes.Buffer
|
||||
message.Write(&reply)
|
||||
h.logger.Debugf("Sending %s", message)
|
||||
return reply.Bytes(), nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
return ConnectionState{
|
||||
ServerName: h.sni,
|
||||
HandshakeComplete: h.receivedForwardSecurePacket,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
|
||||
if len(nonce) != 32 {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
|
||||
}
|
||||
if !bytes.Equal(nonce[4:12], h.scfg.obit) {
|
||||
return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,734 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/mocks/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockKEX struct {
|
||||
ephermal bool
|
||||
sharedKeyError error
|
||||
}
|
||||
|
||||
func (m *mockKEX) PublicKey() []byte {
|
||||
if m.ephermal {
|
||||
return []byte("ephermal pub")
|
||||
}
|
||||
return []byte("initial public")
|
||||
}
|
||||
|
||||
func (m *mockKEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
|
||||
if m.sharedKeyError != nil {
|
||||
return nil, m.sharedKeyError
|
||||
}
|
||||
if m.ephermal {
|
||||
return []byte("shared ephermal"), nil
|
||||
}
|
||||
return []byte("shared key"), nil
|
||||
}
|
||||
|
||||
type mockSigner struct {
|
||||
gotCHLO bool
|
||||
}
|
||||
|
||||
func (s *mockSigner) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
|
||||
if len(chlo) > 0 {
|
||||
s.gotCHLO = true
|
||||
}
|
||||
return []byte("proof"), nil
|
||||
}
|
||||
func (*mockSigner) GetCertsCompressed(sni string, common, cached []byte) ([]byte, error) {
|
||||
return []byte("certcompressed"), nil
|
||||
}
|
||||
func (*mockSigner) GetLeafCert(sni string) ([]byte, error) {
|
||||
return []byte("certuncompressed"), nil
|
||||
}
|
||||
|
||||
func mockQuicCryptoKeyDerivation(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
|
||||
return mockcrypto.NewMockAEAD(mockCtrl), nil
|
||||
}
|
||||
|
||||
type mockStream struct {
|
||||
unblockRead chan struct{}
|
||||
dataToRead bytes.Buffer
|
||||
dataWritten bytes.Buffer
|
||||
}
|
||||
|
||||
var _ io.ReadWriter = &mockStream{}
|
||||
|
||||
var errMockStreamClosing = errors.New("mock stream closing")
|
||||
|
||||
func newMockStream() *mockStream {
|
||||
return &mockStream{unblockRead: make(chan struct{})}
|
||||
}
|
||||
|
||||
// call Close to make Read return
|
||||
func (s *mockStream) Read(p []byte) (int, error) {
|
||||
n, _ := s.dataToRead.Read(p)
|
||||
if n == 0 { // block if there's no data
|
||||
<-s.unblockRead
|
||||
return 0, errMockStreamClosing
|
||||
}
|
||||
return n, nil // never return an EOF
|
||||
}
|
||||
|
||||
func (s *mockStream) Write(p []byte) (int, error) {
|
||||
return s.dataWritten.Write(p)
|
||||
}
|
||||
|
||||
func (s *mockStream) close() {
|
||||
close(s.unblockRead)
|
||||
}
|
||||
|
||||
type mockCookieProtector struct {
|
||||
decodeErr error
|
||||
}
|
||||
|
||||
var _ cookieProtector = &mockCookieProtector{}
|
||||
|
||||
func (mockCookieProtector) NewToken(sourceAddr []byte) ([]byte, error) {
|
||||
return append([]byte("token "), sourceAddr...), nil
|
||||
}
|
||||
|
||||
func (s mockCookieProtector) DecodeToken(data []byte) ([]byte, error) {
|
||||
if s.decodeErr != nil {
|
||||
return nil, s.decodeErr
|
||||
}
|
||||
if len(data) < 6 {
|
||||
return nil, errors.New("token too short")
|
||||
}
|
||||
return data[6:], nil
|
||||
}
|
||||
|
||||
var _ = Describe("Server Crypto Setup", func() {
|
||||
var (
|
||||
kex *mockKEX
|
||||
signer *mockSigner
|
||||
scfg *ServerConfig
|
||||
cs *cryptoSetupServer
|
||||
stream *mockStream
|
||||
paramsChan chan TransportParameters
|
||||
handshakeEvent chan struct{}
|
||||
handshakeComplete chan struct{}
|
||||
nonce32 []byte
|
||||
versionTag []byte
|
||||
validSTK []byte
|
||||
aead []byte
|
||||
kexs []byte
|
||||
version protocol.VersionNumber
|
||||
supportedVersions []protocol.VersionNumber
|
||||
sourceAddrValid bool
|
||||
)
|
||||
|
||||
const (
|
||||
expectedInitialNonceLen = 32
|
||||
expectedFSNonceLen = 64
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
|
||||
|
||||
// use a buffered channel here, so that we can parse a CHLO without having to receive the TransportParameters to avoid blocking
|
||||
paramsChan = make(chan TransportParameters, 1)
|
||||
handshakeEvent = make(chan struct{}, 2)
|
||||
handshakeComplete = make(chan struct{})
|
||||
stream = newMockStream()
|
||||
kex = &mockKEX{}
|
||||
signer = &mockSigner{}
|
||||
scfg, err = NewServerConfig(kex, signer)
|
||||
nonce32 = make([]byte, 32)
|
||||
aead = []byte("AESG")
|
||||
kexs = []byte("C255")
|
||||
copy(nonce32[4:12], scfg.obit) // set the OBIT value at the right position
|
||||
versionTag = make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(versionTag, uint32(protocol.VersionWhatever))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
||||
supportedVersions = []protocol.VersionNumber{version, 98, 99}
|
||||
csInt, err := NewCryptoSetup(
|
||||
stream,
|
||||
protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
remoteAddr,
|
||||
version,
|
||||
make([]byte, 32), // div nonce
|
||||
scfg,
|
||||
&TransportParameters{IdleTimeout: protocol.DefaultIdleTimeout},
|
||||
supportedVersions,
|
||||
nil,
|
||||
paramsChan,
|
||||
handshakeEvent,
|
||||
handshakeComplete,
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
cs = csInt.(*cryptoSetupServer)
|
||||
cs.scfg.cookieGenerator.cookieProtector = &mockCookieProtector{}
|
||||
validSTK, err = cs.scfg.cookieGenerator.NewToken(remoteAddr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
sourceAddrValid = true
|
||||
cs.acceptSTKCallback = func(_ net.Addr, _ *Cookie) bool { return sourceAddrValid }
|
||||
cs.keyDerivation = mockQuicCryptoKeyDerivation
|
||||
cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil }
|
||||
cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl)
|
||||
cs.cryptoStream = stream
|
||||
})
|
||||
|
||||
Context("when responding to client messages", func() {
|
||||
var cert []byte
|
||||
var xlct []byte
|
||||
var fullCHLO map[Tag][]byte
|
||||
|
||||
BeforeEach(func() {
|
||||
xlct = make([]byte, 8)
|
||||
var err error
|
||||
cert, err = cs.scfg.certChain.GetLeafCert("")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
binary.LittleEndian.PutUint64(xlct, crypto.HashCert(cert))
|
||||
fullCHLO = map[Tag][]byte{
|
||||
TagSCID: scfg.ID,
|
||||
TagSNI: []byte("quic.clemente.io"),
|
||||
TagNONC: nonce32,
|
||||
TagSTK: validSTK,
|
||||
TagXLCT: xlct,
|
||||
TagAEAD: aead,
|
||||
TagKEXS: kexs,
|
||||
TagPUBS: bytes.Repeat([]byte{'e'}, 31),
|
||||
TagVER: versionTag,
|
||||
}
|
||||
})
|
||||
|
||||
It("doesn't support Chrome's no STOP_WAITING experiment", func() {
|
||||
HandshakeMessage{
|
||||
Tag: TagCHLO,
|
||||
Data: map[Tag][]byte{
|
||||
TagNSTP: []byte("foobar"),
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(ErrNSTPExperiment))
|
||||
})
|
||||
|
||||
It("reads the transport parameters sent by the client", func() {
|
||||
sourceAddrValid = true
|
||||
fullCHLO[TagICSL] = []byte{0x37, 0x13, 0, 0}
|
||||
_, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), fullCHLO)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var params TransportParameters
|
||||
Expect(paramsChan).To(Receive(¶ms))
|
||||
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
||||
})
|
||||
|
||||
It("generates REJ messages", func() {
|
||||
sourceAddrValid = false
|
||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response).To(HavePrefix("REJ"))
|
||||
Expect(response).To(ContainSubstring("initial public"))
|
||||
Expect(response).ToNot(ContainSubstring("certcompressed"))
|
||||
Expect(response).ToNot(ContainSubstring("proof"))
|
||||
Expect(signer.gotCHLO).To(BeFalse())
|
||||
})
|
||||
|
||||
It("REJ messages don't include cert or proof without STK", func() {
|
||||
sourceAddrValid = false
|
||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response).To(HavePrefix("REJ"))
|
||||
Expect(response).ToNot(ContainSubstring("certcompressed"))
|
||||
Expect(response).ToNot(ContainSubstring("proof"))
|
||||
Expect(signer.gotCHLO).To(BeFalse())
|
||||
})
|
||||
|
||||
It("REJ messages include cert and proof with valid STK", func() {
|
||||
sourceAddrValid = true
|
||||
response, err := cs.handleInchoateCHLO("", bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize), map[Tag][]byte{
|
||||
TagSTK: validSTK,
|
||||
TagSNI: []byte("foo"),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response).To(HavePrefix("REJ"))
|
||||
Expect(response).To(ContainSubstring("certcompressed"))
|
||||
Expect(response).To(ContainSubstring("proof"))
|
||||
Expect(signer.gotCHLO).To(BeTrue())
|
||||
})
|
||||
|
||||
It("generates SHLO messages", func() {
|
||||
var checkedSecure, checkedForwardSecure bool
|
||||
cs.keyDerivation = func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
|
||||
if forwardSecure {
|
||||
Expect(nonces).To(HaveLen(expectedFSNonceLen))
|
||||
checkedForwardSecure = true
|
||||
Expect(sharedSecret).To(Equal([]byte("shared ephermal")))
|
||||
} else {
|
||||
Expect(nonces).To(HaveLen(expectedInitialNonceLen))
|
||||
Expect(sharedSecret).To(Equal([]byte("shared key")))
|
||||
checkedSecure = true
|
||||
}
|
||||
return mockcrypto.NewMockAEAD(mockCtrl), nil
|
||||
}
|
||||
|
||||
response, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
|
||||
TagPUBS: []byte("pubs-c"),
|
||||
TagNONC: nonce32,
|
||||
TagAEAD: aead,
|
||||
TagKEXS: kexs,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response).To(HavePrefix("SHLO"))
|
||||
message, err := ParseHandshakeMessage(bytes.NewReader(response))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(message.Data).To(HaveKeyWithValue(TagPUBS, []byte("ephermal pub")))
|
||||
Expect(message.Data).To(HaveKey(TagSNO))
|
||||
Expect(message.Data).To(HaveKey(TagVER))
|
||||
Expect(message.Data[TagVER]).To(HaveLen(4 * len(supportedVersions)))
|
||||
for _, v := range supportedVersions {
|
||||
b := &bytes.Buffer{}
|
||||
utils.BigEndian.WriteUint32(b, uint32(v))
|
||||
Expect(message.Data[TagVER]).To(ContainSubstring(b.String()))
|
||||
}
|
||||
Expect(checkedSecure).To(BeTrue())
|
||||
Expect(checkedForwardSecure).To(BeTrue())
|
||||
})
|
||||
|
||||
It("handles long handshake", func() {
|
||||
HandshakeMessage{
|
||||
Tag: TagCHLO,
|
||||
Data: map[Tag][]byte{
|
||||
TagSNI: []byte("quic.clemente.io"),
|
||||
TagSTK: validSTK,
|
||||
TagPAD: bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ"))
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring("SHLO"))
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||
Expect(handshakeComplete).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("rejects client nonces that have the wrong length", func() {
|
||||
fullCHLO[TagNONC] = []byte("too short client nonce")
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")))
|
||||
})
|
||||
|
||||
It("rejects client nonces that have the wrong OBIT value", func() {
|
||||
fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")))
|
||||
})
|
||||
|
||||
It("errors if it can't calculate a shared key", func() {
|
||||
testErr := errors.New("test error")
|
||||
kex.sharedKeyError = testErr
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
It("handles 0-RTT handshake", func() {
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO"))
|
||||
Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ"))
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to forward secure
|
||||
Expect(handshakeComplete).ToNot(BeClosed())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs missing SCID", func() {
|
||||
delete(fullCHLO, TagSCID)
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs missing PUBS", func() {
|
||||
delete(fullCHLO, TagPUBS)
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs with missing XLCT", func() {
|
||||
delete(fullCHLO, TagXLCT)
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs with wrong length XLCT", func() {
|
||||
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 7) // should be 8 bytes
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs with wrong XLCT", func() {
|
||||
fullCHLO[TagXLCT] = bytes.Repeat([]byte{'f'}, 8)
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes inchoate CHLOs with an invalid STK", func() {
|
||||
testErr := errors.New("STK invalid")
|
||||
cs.scfg.cookieGenerator.cookieProtector.(*mockCookieProtector).decodeErr = testErr
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("recognizes proper CHLOs", func() {
|
||||
Expect(cs.isInchoateCHLO(fullCHLO, cert)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("rejects CHLOs without the version tag", func() {
|
||||
HandshakeMessage{
|
||||
Tag: TagCHLO,
|
||||
Data: map[Tag][]byte{
|
||||
TagSCID: scfg.ID,
|
||||
TagSNI: []byte("quic.clemente.io"),
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")))
|
||||
})
|
||||
|
||||
It("rejects CHLOs with a version tag that has the wrong length", func() {
|
||||
fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")))
|
||||
})
|
||||
|
||||
It("detects version downgrade attacks", func() {
|
||||
highestSupportedVersion := supportedVersions[len(supportedVersions)-1]
|
||||
lowestSupportedVersion := supportedVersions[0]
|
||||
Expect(highestSupportedVersion).ToNot(Equal(lowestSupportedVersion))
|
||||
cs.version = highestSupportedVersion
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, uint32(lowestSupportedVersion))
|
||||
fullCHLO[TagVER] = b
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")))
|
||||
})
|
||||
|
||||
It("accepts a non-matching version tag in the CHLO, if it is an unsupported version", func() {
|
||||
supportedVersion := protocol.SupportedVersions[0]
|
||||
unsupportedVersion := supportedVersion + 1000
|
||||
Expect(protocol.IsSupportedVersion(supportedVersions, unsupportedVersion)).To(BeFalse())
|
||||
cs.version = supportedVersion
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, uint32(unsupportedVersion))
|
||||
fullCHLO[TagVER] = b
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors if the AEAD tag is missing", func() {
|
||||
delete(fullCHLO, TagAEAD)
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
|
||||
It("errors if the AEAD tag has the wrong value", func() {
|
||||
fullCHLO[TagAEAD] = []byte("wrong")
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
|
||||
It("errors if the KEXS tag is missing", func() {
|
||||
delete(fullCHLO, TagKEXS)
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
|
||||
It("errors if the KEXS tag has the wrong value", func() {
|
||||
fullCHLO[TagKEXS] = []byte("wrong")
|
||||
HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")))
|
||||
})
|
||||
})
|
||||
|
||||
It("errors without SNI", func() {
|
||||
HandshakeMessage{
|
||||
Tag: TagCHLO,
|
||||
Data: map[Tag][]byte{
|
||||
TagSTK: validSTK,
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
|
||||
})
|
||||
|
||||
It("errors with empty SNI", func() {
|
||||
HandshakeMessage{
|
||||
Tag: TagCHLO,
|
||||
Data: map[Tag][]byte{
|
||||
TagSTK: validSTK,
|
||||
TagSNI: nil,
|
||||
},
|
||||
}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required"))
|
||||
})
|
||||
|
||||
It("errors with invalid message", func() {
|
||||
stream.dataToRead.Write([]byte("invalid message"))
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.HandshakeFailed))
|
||||
})
|
||||
|
||||
It("errors with non-CHLO message", func() {
|
||||
HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead)
|
||||
err := cs.RunHandshake()
|
||||
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
|
||||
})
|
||||
|
||||
Context("escalating crypto", func() {
|
||||
doCHLO := func() {
|
||||
_, err := cs.handleCHLO("", []byte("chlo-data"), map[Tag][]byte{
|
||||
TagPUBS: []byte("pubs-c"),
|
||||
TagNONC: nonce32,
|
||||
TagAEAD: aead,
|
||||
TagKEXS: kexs,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handshakeEvent).To(Receive()) // for the switch to secure
|
||||
close(cs.sentSHLO)
|
||||
}
|
||||
|
||||
Context("null encryption", func() {
|
||||
It("is used initially", func() {
|
||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar signed"))
|
||||
enc, sealer := cs.GetSealer()
|
||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
||||
d := sealer.Seal(nil, []byte("foobar"), 10, []byte{})
|
||||
Expect(d).To(Equal([]byte("foobar signed")))
|
||||
})
|
||||
|
||||
It("is used for the crypto stream", func() {
|
||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
|
||||
enc, sealer := cs.GetSealerForCryptoStream()
|
||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
||||
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
|
||||
})
|
||||
|
||||
It("is accepted initially", func() {
|
||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(5), []byte{}).Return([]byte("decrypted"), nil)
|
||||
d, enc, err := cs.Open(nil, []byte("unencrypted"), 5, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(d).To(Equal([]byte("decrypted")))
|
||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
||||
})
|
||||
|
||||
It("errors if the has the wrong hash", func() {
|
||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("not unencrypted"), protocol.PacketNumber(5), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
_, enc, err := cs.Open(nil, []byte("not unencrypted"), 5, []byte{})
|
||||
Expect(err).To(MatchError("authentication failed"))
|
||||
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
|
||||
})
|
||||
|
||||
It("is still accepted after CHLO", func() {
|
||||
doCHLO()
|
||||
// it tries forward secure and secure decryption first
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{})
|
||||
Expect(cs.secureAEAD).ToNot(BeNil())
|
||||
_, enc, err := cs.Open(nil, []byte("unencrypted"), 99, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(enc).To(Equal(protocol.EncryptionUnencrypted))
|
||||
})
|
||||
|
||||
It("is not accepted after receiving secure packet", func() {
|
||||
doCHLO()
|
||||
// first receive a secure packet
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
|
||||
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
|
||||
Expect(enc).To(Equal(protocol.EncryptionSecure))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(d).To(Equal([]byte("decrypted")))
|
||||
// now receive an unencrypted packet
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("unencrypted"), protocol.PacketNumber(99), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
_, enc, err = cs.Open(nil, []byte("unencrypted"), 99, []byte{})
|
||||
Expect(err).To(MatchError("authentication failed"))
|
||||
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
|
||||
})
|
||||
|
||||
It("is not used after CHLO", func() {
|
||||
doCHLO()
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(0), []byte{})
|
||||
enc, sealer := cs.GetSealer()
|
||||
Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted))
|
||||
sealer.Seal(nil, []byte("foobar"), 0, []byte{})
|
||||
})
|
||||
})
|
||||
|
||||
Context("initial encryption", func() {
|
||||
It("is accepted after CHLO", func() {
|
||||
doCHLO()
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(98), []byte{}).Return([]byte("decrypted"), nil)
|
||||
d, enc, err := cs.Open(nil, []byte("encrypted"), 98, []byte{})
|
||||
Expect(enc).To(Equal(protocol.EncryptionSecure))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(d).To(Equal([]byte("decrypted")))
|
||||
})
|
||||
|
||||
It("is not accepted after receiving forward secure packet", func() {
|
||||
doCHLO()
|
||||
// receive a forward secure packet
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
|
||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// receive a secure packet
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(12), []byte{}).Return(nil, errors.New("authentication failed"))
|
||||
_, enc, err := cs.Open(nil, []byte("encrypted"), 12, []byte{})
|
||||
Expect(err).To(MatchError("authentication failed"))
|
||||
Expect(enc).To(Equal(protocol.EncryptionUnspecified))
|
||||
})
|
||||
|
||||
It("is used for the crypto stream", func() {
|
||||
doCHLO()
|
||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(1), []byte{}).Return([]byte("foobar crypto stream"))
|
||||
enc, sealer := cs.GetSealerForCryptoStream()
|
||||
Expect(enc).To(Equal(protocol.EncryptionSecure))
|
||||
d := sealer.Seal(nil, []byte("foobar"), 1, []byte{})
|
||||
Expect(d).To(Equal([]byte("foobar crypto stream")))
|
||||
})
|
||||
})
|
||||
|
||||
Context("forward secure encryption", func() {
|
||||
It("is used after the CHLO", func() {
|
||||
doCHLO()
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar forward sec"))
|
||||
enc, sealer := cs.GetSealer()
|
||||
Expect(enc).To(Equal(protocol.EncryptionForwardSecure))
|
||||
d := sealer.Seal(nil, []byte("foobar"), 20, []byte{})
|
||||
Expect(d).To(Equal([]byte("foobar forward sec")))
|
||||
})
|
||||
|
||||
It("regards the handshake as complete once it receives a forward encrypted packet", func() {
|
||||
doCHLO()
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(200), []byte{})
|
||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 200, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(handshakeComplete).To(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("reporting the connection state", func() {
|
||||
It("reports before the handshake completes", func() {
|
||||
cs.sni = "server name"
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeFalse())
|
||||
Expect(state.ServerName).To(Equal("server name"))
|
||||
})
|
||||
|
||||
It("reports after the handshake completes", func() {
|
||||
doCHLO()
|
||||
// receive a forward secure packet
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("forward secure encrypted"), protocol.PacketNumber(11), []byte{})
|
||||
_, _, err := cs.Open(nil, []byte("forward secure encrypted"), 11, []byte{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
state := cs.ConnectionState()
|
||||
Expect(state.HandshakeComplete).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("forcing encryption levels", func() {
|
||||
It("forces null encryption", func() {
|
||||
cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(11), []byte{}).Return([]byte("foobar unencrypted"))
|
||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
d := sealer.Seal(nil, []byte("foobar"), 11, []byte{})
|
||||
Expect(d).To(Equal([]byte("foobar unencrypted")))
|
||||
})
|
||||
|
||||
It("forces initial encryption", func() {
|
||||
doCHLO()
|
||||
cs.secureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(12), []byte{}).Return([]byte("foobar secure"))
|
||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
d := sealer.Seal(nil, []byte("foobar"), 12, []byte{})
|
||||
Expect(d).To(Equal([]byte("foobar secure")))
|
||||
})
|
||||
|
||||
It("errors if no AEAD for initial encryption is available", func() {
|
||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure)
|
||||
Expect(err).To(MatchError("CryptoSetupServer: no secureAEAD"))
|
||||
Expect(sealer).To(BeNil())
|
||||
})
|
||||
|
||||
It("forces forward-secure encryption", func() {
|
||||
doCHLO()
|
||||
cs.forwardSecureAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(13), []byte{}).Return([]byte("foobar forward sec"))
|
||||
sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
d := sealer.Seal(nil, []byte("foobar"), 13, []byte{})
|
||||
Expect(d).To(Equal([]byte("foobar forward sec")))
|
||||
})
|
||||
|
||||
It("errors of no AEAD for forward-secure encryption is available", func() {
|
||||
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure)
|
||||
Expect(err).To(MatchError("CryptoSetupServer: no forwardSecureAEAD"))
|
||||
Expect(seal).To(BeNil())
|
||||
})
|
||||
|
||||
It("errors if no encryption level is specified", func() {
|
||||
seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified)
|
||||
Expect(err).To(MatchError("CryptoSetupServer: no encryption level specified"))
|
||||
Expect(seal).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("STK verification and creation", func() {
|
||||
It("requires STK", func() {
|
||||
sourceAddrValid = false
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
|
||||
map[Tag][]byte{
|
||||
TagSNI: []byte("foo"),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK)))
|
||||
Expect(cs.sni).To(Equal("foo"))
|
||||
})
|
||||
|
||||
It("works with proper STK", func() {
|
||||
sourceAddrValid = true
|
||||
done, err := cs.handleMessage(
|
||||
bytes.Repeat([]byte{'a'}, protocol.MinClientHelloSize),
|
||||
map[Tag][]byte{
|
||||
TagSNI: []byte("foo"),
|
||||
TagVER: versionTag,
|
||||
},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(done).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
|
@ -57,7 +57,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("returns Handshake() when an error occurs", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
|
@ -87,7 +87,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("returns Handshake() when handling a message fails", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
|
@ -116,7 +116,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("returns Handshake() when it is closed", func() {
|
||||
_, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
|
@ -162,9 +162,9 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
}
|
||||
|
||||
handshake := func(
|
||||
client CryptoSetupTLS,
|
||||
client CryptoSetup,
|
||||
cChunkChan <-chan chunk,
|
||||
server CryptoSetupTLS,
|
||||
server CryptoSetup,
|
||||
sChunkChan <-chan chunk) (error /* client error */, error /* server error */) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
@ -195,7 +195,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
client, _, err := NewCryptoSetupTLSClient(
|
||||
client, _, err := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
|
@ -211,7 +211,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
|
@ -250,7 +250,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
|
||||
It("signals when it has written the ClientHello", func() {
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
client, chChan, err := NewCryptoSetupTLSClient(
|
||||
client, chChan, err := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
|
@ -289,7 +289,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters
|
||||
cChunkChan, cInitialStream, cHandshakeStream := initStreams()
|
||||
cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second}
|
||||
client, _, err := NewCryptoSetupTLSClient(
|
||||
client, _, err := NewCryptoSetupClient(
|
||||
cInitialStream,
|
||||
cHandshakeStream,
|
||||
protocol.ConnectionID{},
|
||||
|
@ -309,7 +309,7 @@ var _ = Describe("Crypto Setup TLS", func() {
|
|||
IdleTimeout: 0x1337 * time.Second,
|
||||
StatelessResetToken: bytes.Repeat([]byte{42}, 16),
|
||||
}
|
||||
server, err := NewCryptoSetupTLSServer(
|
||||
server, err := NewCryptoSetupServer(
|
||||
sInitialStream,
|
||||
sHandshakeStream,
|
||||
protocol.ConnectionID{},
|
File diff suppressed because one or more lines are too long
|
@ -1,48 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
kexLifetime = protocol.EphermalKeyLifetime
|
||||
kexCurrent crypto.KeyExchange
|
||||
kexCurrentTime time.Time
|
||||
kexMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// getEphermalKEX returns the currently active KEX, which changes every protocol.EphermalKeyLifetime
|
||||
// See the explanation from the QUIC crypto doc:
|
||||
//
|
||||
// A single connection is the usual scope for forward security, but the security
|
||||
// difference between an ephemeral key used for a single connection, and one
|
||||
// used for all connections for 60 seconds is negligible. Thus we can amortise
|
||||
// the Diffie-Hellman key generation at the server over all the connections in a
|
||||
// small time span.
|
||||
func getEphermalKEX() (crypto.KeyExchange, error) {
|
||||
kexMutex.RLock()
|
||||
res := kexCurrent
|
||||
t := kexCurrentTime
|
||||
kexMutex.RUnlock()
|
||||
if res != nil && time.Since(t) < kexLifetime {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
kexMutex.Lock()
|
||||
defer kexMutex.Unlock()
|
||||
// Check if still unfulfilled
|
||||
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
|
||||
kex, err := crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kexCurrent = kex
|
||||
kexCurrentTime = time.Now()
|
||||
return kexCurrent, nil
|
||||
}
|
||||
return kexCurrent, nil
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Ephermal KEX", func() {
|
||||
It("has a consistent KEX", func() {
|
||||
kex1, err := getEphermalKEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(kex1).ToNot(BeNil())
|
||||
kex2, err := getEphermalKEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(kex2).ToNot(BeNil())
|
||||
Expect(kex1).To(Equal(kex2))
|
||||
})
|
||||
|
||||
It("changes KEX", func() {
|
||||
kexLifetime = 10 * time.Millisecond
|
||||
defer func() {
|
||||
kexLifetime = protocol.EphermalKeyLifetime
|
||||
}()
|
||||
kex, err := getEphermalKEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(kex).ToNot(BeNil())
|
||||
time.Sleep(kexLifetime)
|
||||
kex2, err := getEphermalKEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(kex2).ToNot(Equal(kex))
|
||||
})
|
||||
})
|
|
@ -1,137 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// A HandshakeMessage is a handshake message
|
||||
type HandshakeMessage struct {
|
||||
Tag Tag
|
||||
Data map[Tag][]byte
|
||||
}
|
||||
|
||||
var _ fmt.Stringer = &HandshakeMessage{}
|
||||
|
||||
// ParseHandshakeMessage reads a crypto message
|
||||
func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
|
||||
slice4 := make([]byte, 4)
|
||||
|
||||
if _, err := io.ReadFull(r, slice4); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
messageTag := Tag(binary.LittleEndian.Uint32(slice4))
|
||||
|
||||
if _, err := io.ReadFull(r, slice4); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
nPairs := binary.LittleEndian.Uint32(slice4)
|
||||
|
||||
if nPairs > protocol.CryptoMaxParams {
|
||||
return HandshakeMessage{}, qerr.CryptoTooManyEntries
|
||||
}
|
||||
|
||||
index := make([]byte, nPairs*8)
|
||||
if _, err := io.ReadFull(r, index); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
|
||||
resultMap := map[Tag][]byte{}
|
||||
|
||||
var dataStart uint32
|
||||
for indexPos := 0; indexPos < int(nPairs)*8; indexPos += 8 {
|
||||
tag := Tag(binary.LittleEndian.Uint32(index[indexPos : indexPos+4]))
|
||||
dataEnd := binary.LittleEndian.Uint32(index[indexPos+4 : indexPos+8])
|
||||
|
||||
dataLen := dataEnd - dataStart
|
||||
if dataLen > protocol.CryptoParameterMaxLength {
|
||||
return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
|
||||
}
|
||||
|
||||
data := make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(r, data); err != nil {
|
||||
return HandshakeMessage{}, err
|
||||
}
|
||||
|
||||
resultMap[tag] = data
|
||||
dataStart = dataEnd
|
||||
}
|
||||
|
||||
return HandshakeMessage{
|
||||
Tag: messageTag,
|
||||
Data: resultMap}, nil
|
||||
}
|
||||
|
||||
// Write writes a crypto message
|
||||
func (h HandshakeMessage) Write(b *bytes.Buffer) {
|
||||
data := h.Data
|
||||
utils.LittleEndian.WriteUint32(b, uint32(h.Tag))
|
||||
utils.LittleEndian.WriteUint16(b, uint16(len(data)))
|
||||
utils.LittleEndian.WriteUint16(b, 0)
|
||||
|
||||
// Save current position in the buffer, so that we can update the index in-place later
|
||||
indexStart := b.Len()
|
||||
|
||||
indexData := make([]byte, 8*len(data))
|
||||
b.Write(indexData) // Will be updated later
|
||||
|
||||
offset := uint32(0)
|
||||
for i, t := range h.getTagsSorted() {
|
||||
v := data[t]
|
||||
b.Write(v)
|
||||
offset += uint32(len(v))
|
||||
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
|
||||
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
|
||||
}
|
||||
|
||||
// Now we write the index data for real
|
||||
copy(b.Bytes()[indexStart:], indexData)
|
||||
}
|
||||
|
||||
func (h *HandshakeMessage) getTagsSorted() []Tag {
|
||||
tags := make([]Tag, len(h.Data))
|
||||
i := 0
|
||||
for t := range h.Data {
|
||||
tags[i] = t
|
||||
i++
|
||||
}
|
||||
sort.Slice(tags, func(i, j int) bool {
|
||||
return tags[i] < tags[j]
|
||||
})
|
||||
return tags
|
||||
}
|
||||
|
||||
func (h HandshakeMessage) String() string {
|
||||
var pad string
|
||||
res := tagToString(h.Tag) + ":\n"
|
||||
for _, tag := range h.getTagsSorted() {
|
||||
if tag == TagPAD {
|
||||
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
|
||||
} else {
|
||||
res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
|
||||
}
|
||||
}
|
||||
|
||||
if len(pad) > 0 {
|
||||
res += pad
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func tagToString(tag Tag) string {
|
||||
b := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(b, uint32(tag))
|
||||
for i := range b {
|
||||
if b[i] == 0 {
|
||||
b[i] = ' '
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Handshake Message", func() {
|
||||
Context("when parsing", func() {
|
||||
It("parses sample CHLO message", func() {
|
||||
msg, err := ParseHandshakeMessage(bytes.NewReader(sampleCHLO))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(msg.Tag).To(Equal(TagCHLO))
|
||||
Expect(msg.Data).To(Equal(sampleCHLOMap))
|
||||
})
|
||||
|
||||
It("rejects large numbers of pairs", func() {
|
||||
r := bytes.NewReader([]byte("CHLO\xff\xff\xff\xff"))
|
||||
_, err := ParseHandshakeMessage(r)
|
||||
Expect(err).To(MatchError(qerr.CryptoTooManyEntries))
|
||||
})
|
||||
|
||||
It("rejects too long values", func() {
|
||||
r := bytes.NewReader([]byte{
|
||||
'C', 'H', 'L', 'O',
|
||||
1, 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
0xff, 0xff, 0xff, 0xff,
|
||||
})
|
||||
_, err := ParseHandshakeMessage(r)
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.CryptoInvalidValueLength, "value too long")))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when writing", func() {
|
||||
It("writes sample message", func() {
|
||||
b := &bytes.Buffer{}
|
||||
HandshakeMessage{Tag: TagCHLO, Data: sampleCHLOMap}.Write(b)
|
||||
Expect(b.Bytes()).To(Equal(sampleCHLO))
|
||||
})
|
||||
})
|
||||
|
||||
Context("string representation", func() {
|
||||
It("has a string representation", func() {
|
||||
str := HandshakeMessage{
|
||||
Tag: TagSHLO,
|
||||
Data: map[Tag][]byte{
|
||||
TagAEAD: []byte("foobar"),
|
||||
TagEXPY: []byte("raboof"),
|
||||
},
|
||||
}.String()
|
||||
Expect(str[:4]).To(Equal("SHLO"))
|
||||
Expect(str).To(ContainSubstring("AEAD: \"foobar\""))
|
||||
Expect(str).To(ContainSubstring("EXPY: \"raboof\""))
|
||||
})
|
||||
|
||||
It("lists padding separately", func() {
|
||||
str := HandshakeMessage{
|
||||
Tag: TagSHLO,
|
||||
Data: map[Tag][]byte{
|
||||
TagPAD: bytes.Repeat([]byte{0}, 1337),
|
||||
},
|
||||
}.String()
|
||||
Expect(str).To(ContainSubstring("PAD"))
|
||||
Expect(str).To(ContainSubstring("1337 bytes"))
|
||||
})
|
||||
})
|
||||
})
|
|
@ -25,28 +25,17 @@ type tlsExtensionHandler interface {
|
|||
ReceivedExtensions(msgType uint8, exts []qtls.Extension) error
|
||||
}
|
||||
|
||||
type baseCryptoSetup interface {
|
||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||
type CryptoSetup interface {
|
||||
RunHandshake() error
|
||||
io.Closer
|
||||
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
}
|
||||
|
||||
// CryptoSetup is the crypto setup used by gQUIC
|
||||
type CryptoSetup interface {
|
||||
baseCryptoSetup
|
||||
|
||||
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
|
||||
Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, protocol.EncryptionLevel, error)
|
||||
}
|
||||
|
||||
// CryptoSetupTLS is the crypto setup used by IETF QUIC
|
||||
type CryptoSetupTLS interface {
|
||||
baseCryptoSetup
|
||||
|
||||
io.Closer
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error)
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
)
|
||||
|
||||
// ServerConfig is a server config
|
||||
type ServerConfig struct {
|
||||
kex crypto.KeyExchange
|
||||
certChain crypto.CertChain
|
||||
ID []byte
|
||||
obit []byte
|
||||
cookieGenerator *CookieGenerator
|
||||
}
|
||||
|
||||
// NewServerConfig creates a new server config
|
||||
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
|
||||
id := make([]byte, 16)
|
||||
_, err := rand.Read(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
obit := make([]byte, 8)
|
||||
if _, err = rand.Read(obit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookieGenerator, err := NewCookieGenerator()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ServerConfig{
|
||||
kex: kex,
|
||||
certChain: certChain,
|
||||
ID: id,
|
||||
obit: obit,
|
||||
cookieGenerator: cookieGenerator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get the server config binary representation
|
||||
func (s *ServerConfig) Get() []byte {
|
||||
var serverConfig bytes.Buffer
|
||||
msg := HandshakeMessage{
|
||||
Tag: TagSCFG,
|
||||
Data: map[Tag][]byte{
|
||||
TagSCID: s.ID,
|
||||
TagKEXS: []byte("C255"),
|
||||
TagAEAD: []byte("AESG"),
|
||||
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
|
||||
TagOBIT: s.obit,
|
||||
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
},
|
||||
}
|
||||
msg.Write(&serverConfig)
|
||||
return serverConfig.Bytes()
|
||||
}
|
||||
|
||||
// Sign the server config and CHLO with the server's keyData
|
||||
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
|
||||
return s.certChain.SignServerProof(sni, chlo, s.Get())
|
||||
}
|
||||
|
||||
// GetCertsCompressed returns the certificate data
|
||||
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
|
||||
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
|
||||
}
|
|
@ -1,184 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
type serverConfigClient struct {
|
||||
raw []byte
|
||||
ID []byte
|
||||
obit []byte
|
||||
expiry time.Time
|
||||
|
||||
kex crypto.KeyExchange
|
||||
sharedSecret []byte
|
||||
}
|
||||
|
||||
var (
|
||||
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
|
||||
)
|
||||
|
||||
// parseServerConfig parses a server config
|
||||
func parseServerConfig(data []byte) (*serverConfigClient, error) {
|
||||
message, err := ParseHandshakeMessage(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if message.Tag != TagSCFG {
|
||||
return nil, errMessageNotServerConfig
|
||||
}
|
||||
|
||||
scfg := &serverConfigClient{raw: data}
|
||||
err = scfg.parseValues(message.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return scfg, nil
|
||||
}
|
||||
|
||||
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
|
||||
// SCID
|
||||
scfgID, ok := tagMap[TagSCID]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
|
||||
}
|
||||
if len(scfgID) != 16 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
|
||||
}
|
||||
s.ID = scfgID
|
||||
|
||||
// KEXS
|
||||
// TODO: setup Key Exchange
|
||||
kexs, ok := tagMap[TagKEXS]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
|
||||
}
|
||||
if len(kexs)%4 != 0 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
|
||||
}
|
||||
c255Foundat := -1
|
||||
|
||||
for i := 0; i < len(kexs)/4; i++ {
|
||||
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
|
||||
c255Foundat = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if c255Foundat < 0 {
|
||||
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
|
||||
}
|
||||
|
||||
// AEAD
|
||||
aead, ok := tagMap[TagAEAD]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
|
||||
}
|
||||
if len(aead)%4 != 0 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
|
||||
}
|
||||
var aesgFound bool
|
||||
for i := 0; i < len(aead)/4; i++ {
|
||||
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
|
||||
aesgFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !aesgFound {
|
||||
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
|
||||
}
|
||||
|
||||
// PUBS
|
||||
pubs, ok := tagMap[TagPUBS]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
|
||||
}
|
||||
|
||||
var pubsKexs []struct {
|
||||
Length uint32
|
||||
Value []byte
|
||||
}
|
||||
var lastLen uint32
|
||||
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
|
||||
// the PUBS value is always prepended by 3 byte little endian length field
|
||||
|
||||
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
|
||||
if err != nil {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
|
||||
}
|
||||
if lastLen == 0 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
if i+3+int(lastLen) > len(pubs) {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
pubsKexs = append(pubsKexs, struct {
|
||||
Length uint32
|
||||
Value []byte
|
||||
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
|
||||
}
|
||||
|
||||
if c255Foundat >= len(pubsKexs) {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
|
||||
}
|
||||
|
||||
if pubsKexs[c255Foundat].Length != 32 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
|
||||
}
|
||||
|
||||
var err error
|
||||
s.kex, err = crypto.NewCurve25519KEX()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// OBIT
|
||||
obit, ok := tagMap[TagOBIT]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
|
||||
}
|
||||
if len(obit) != 8 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
|
||||
}
|
||||
s.obit = obit
|
||||
|
||||
// EXPY
|
||||
expy, ok := tagMap[TagEXPY]
|
||||
if !ok {
|
||||
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
|
||||
}
|
||||
if len(expy) != 8 {
|
||||
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
|
||||
}
|
||||
// make sure that the value doesn't overflow an int64
|
||||
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
|
||||
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
|
||||
s.expiry = time.Unix(int64(expyTimestamp), 0)
|
||||
|
||||
// TODO: implement VER
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serverConfigClient) IsExpired() bool {
|
||||
return s.expiry.Before(time.Now())
|
||||
}
|
||||
|
||||
func (s *serverConfigClient) Get() []byte {
|
||||
return s.raw
|
||||
}
|
|
@ -1,266 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// This tagMap can be passed to parseValues and is garantueed to not cause any errors
|
||||
func getDefaultServerConfigClient() map[Tag][]byte {
|
||||
return map[Tag][]byte{
|
||||
TagSCID: bytes.Repeat([]byte{'F'}, 16),
|
||||
TagKEXS: []byte("C255"),
|
||||
TagAEAD: []byte("AESG"),
|
||||
TagPUBS: append([]byte{0x20, 0x00, 0x00}, bytes.Repeat([]byte{0}, 32)...),
|
||||
TagOBIT: bytes.Repeat([]byte{0}, 8),
|
||||
TagEXPY: {0x0, 0x6c, 0x57, 0x78, 0, 0, 0, 0}, // 2033-12-24
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("Server Config", func() {
|
||||
var tagMap map[Tag][]byte
|
||||
|
||||
BeforeEach(func() {
|
||||
tagMap = getDefaultServerConfigClient()
|
||||
})
|
||||
|
||||
It("returns the parsed server config", func() {
|
||||
tagMap[TagSCID] = []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}
|
||||
b := &bytes.Buffer{}
|
||||
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
|
||||
scfg, err := parseServerConfig(b.Bytes())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
|
||||
})
|
||||
|
||||
It("saves the raw server config", func() {
|
||||
b := &bytes.Buffer{}
|
||||
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b)
|
||||
scfg, err := parseServerConfig(b.Bytes())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.raw).To(Equal(b.Bytes()))
|
||||
})
|
||||
|
||||
It("tells if a server config is expired", func() {
|
||||
scfg := &serverConfigClient{}
|
||||
scfg.expiry = time.Now().Add(-time.Second)
|
||||
Expect(scfg.IsExpired()).To(BeTrue())
|
||||
scfg.expiry = time.Now().Add(time.Second)
|
||||
Expect(scfg.IsExpired()).To(BeFalse())
|
||||
})
|
||||
|
||||
Context("parsing the server config", func() {
|
||||
It("rejects a handshake message with the wrong message tag", func() {
|
||||
var serverConfig bytes.Buffer
|
||||
HandshakeMessage{Tag: TagCHLO, Data: make(map[Tag][]byte)}.Write(&serverConfig)
|
||||
_, err := parseServerConfig(serverConfig.Bytes())
|
||||
Expect(err).To(MatchError(errMessageNotServerConfig))
|
||||
})
|
||||
|
||||
It("errors on invalid handshake messages", func() {
|
||||
var serverConfig bytes.Buffer
|
||||
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
|
||||
_, err := parseServerConfig(serverConfig.Bytes()[:serverConfig.Len()-2])
|
||||
Expect(err).To(MatchError("unexpected EOF"))
|
||||
})
|
||||
|
||||
It("passes on errors encountered when reading the TagMap", func() {
|
||||
var serverConfig bytes.Buffer
|
||||
HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig)
|
||||
_, err := parseServerConfig(serverConfig.Bytes())
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
|
||||
})
|
||||
|
||||
It("reads an example Handshake Message", func() {
|
||||
var serverConfig bytes.Buffer
|
||||
HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(&serverConfig)
|
||||
scfg, err := parseServerConfig(serverConfig.Bytes())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.ID).To(Equal(tagMap[TagSCID]))
|
||||
Expect(scfg.obit).To(Equal(tagMap[TagOBIT]))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Reading values from the TagMap", func() {
|
||||
var scfg *serverConfigClient
|
||||
|
||||
BeforeEach(func() {
|
||||
scfg = &serverConfigClient{}
|
||||
})
|
||||
|
||||
Context("ServerConfig ID", func() {
|
||||
It("parses the ServerConfig ID", func() {
|
||||
id := []byte{0xb2, 0xa4, 0xbb, 0x8f, 0xf6, 0x51, 0x28, 0xfd, 0x4d, 0xf7, 0xb3, 0x9a, 0x91, 0xe7, 0x91, 0xfb}
|
||||
tagMap[TagSCID] = id
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.ID).To(Equal(id))
|
||||
})
|
||||
|
||||
It("errors if the ServerConfig ID is missing", func() {
|
||||
delete(tagMap, TagSCID)
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID"))
|
||||
})
|
||||
|
||||
It("rejects ServerConfig IDs that have the wrong length", func() {
|
||||
tagMap[TagSCID] = bytes.Repeat([]byte{'F'}, 17) // 1 byte too long
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoInvalidValueLength: SCID"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("KEXS", func() {
|
||||
It("rejects KEXS values that have the wrong length", func() {
|
||||
tagMap[TagKEXS] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoInvalidValueLength: KEXS"))
|
||||
})
|
||||
|
||||
It("rejects KEXS values other than C255", func() {
|
||||
tagMap[TagKEXS] = []byte("P256")
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoNoSupport: KEXS: Could not find C255, other key exchanges are not supported"))
|
||||
})
|
||||
|
||||
It("errors if the KEXS is missing", func() {
|
||||
delete(tagMap, TagKEXS)
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: KEXS"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("AEAD", func() {
|
||||
It("rejects AEAD values that have the wrong length", func() {
|
||||
tagMap[TagAEAD] = bytes.Repeat([]byte{'F'}, 5) // 1 byte too long
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoInvalidValueLength: AEAD"))
|
||||
})
|
||||
|
||||
It("rejects AEAD values other than AESG", func() {
|
||||
tagMap[TagAEAD] = []byte("S20P")
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoNoSupport: AEAD"))
|
||||
})
|
||||
|
||||
It("recognizes AESG in the list of AEADs, at the first position", func() {
|
||||
tagMap[TagAEAD] = []byte("AESGS20P")
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("recognizes AESG in the list of AEADs, not at the first position", func() {
|
||||
tagMap[TagAEAD] = []byte("S20PAESG")
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("errors if the AEAD is missing", func() {
|
||||
delete(tagMap, TagAEAD)
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: AEAD"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("PUBS", func() {
|
||||
It("creates a Curve25519 key exchange", func() {
|
||||
serverKex, err := crypto.NewCurve25519KEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tagMap[TagPUBS] = append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)
|
||||
err = scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
|
||||
})
|
||||
|
||||
It("rejects PUBS values that have the wrong length", func() {
|
||||
tagMap[TagPUBS] = bytes.Repeat([]byte{'F'}, 100) // completely wrong length
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
|
||||
})
|
||||
|
||||
It("rejects PUBS values that have a zero length", func() {
|
||||
tagMap[TagPUBS] = bytes.Repeat([]byte{0}, 100) // completely wrong length
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoInvalidValueLength: PUBS"))
|
||||
})
|
||||
|
||||
It("ensure that C255 Pubs must not be at the first index", func() {
|
||||
serverKex, err := crypto.NewCurve25519KEX()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tagMap[TagKEXS] = []byte("P256C255") // have another KEXS before C255
|
||||
// 3 byte len + 1 byte empty + C255
|
||||
tagMap[TagPUBS] = append([]byte{0x01, 0x00, 0x00, 0x00}, append([]byte{0x20, 0x00, 0x00}, serverKex.PublicKey()...)...)
|
||||
err = scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
sharedSecret, err := serverKex.CalculateSharedKey(scfg.kex.PublicKey())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.sharedSecret).To(Equal(sharedSecret))
|
||||
})
|
||||
|
||||
It("errors if the PUBS is missing", func() {
|
||||
delete(tagMap, TagPUBS)
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: PUBS"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("OBIT", func() {
|
||||
It("parses the OBIT value", func() {
|
||||
obit := []byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}
|
||||
tagMap[TagOBIT] = obit
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.obit).To(Equal(obit))
|
||||
})
|
||||
|
||||
It("errors if the OBIT is missing", func() {
|
||||
delete(tagMap, TagOBIT)
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: OBIT"))
|
||||
})
|
||||
|
||||
It("rejets OBIT values that have the wrong length", func() {
|
||||
tagMap[TagOBIT] = bytes.Repeat([]byte{'F'}, 7) // 1 byte too short
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoInvalidValueLength: OBIT"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("EXPY", func() {
|
||||
It("parses the expiry date", func() {
|
||||
tagMap[TagEXPY] = []byte{0xdc, 0x89, 0x0e, 0x59, 0, 0, 0, 0} // UNIX Timestamp 0x590e89dc = 1494125020
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
year, month, day := scfg.expiry.UTC().Date()
|
||||
Expect(year).To(Equal(2017))
|
||||
Expect(month).To(Equal(time.Month(5)))
|
||||
Expect(day).To(Equal(7))
|
||||
})
|
||||
|
||||
It("errors if the EXPY is missing", func() {
|
||||
delete(tagMap, TagEXPY)
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoMessageParameterNotFound: EXPY"))
|
||||
})
|
||||
|
||||
It("rejects EXPY values that have the wrong length", func() {
|
||||
tagMap[TagEXPY] = bytes.Repeat([]byte{'F'}, 9) // 1 byte too long
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).To(MatchError("CryptoInvalidValueLength: EXPY"))
|
||||
})
|
||||
|
||||
It("deals with absurdly large timestamps", func() {
|
||||
tagMap[TagEXPY] = bytes.Repeat([]byte{0xff}, 8) // this would overflow the int64
|
||||
err := scfg.parseValues(tagMap)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg.expiry.After(time.Now())).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
|
@ -1,45 +0,0 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/crypto"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ServerConfig", func() {
|
||||
var (
|
||||
kex crypto.KeyExchange
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
kex, err = crypto.NewCurve25519KEX()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
})
|
||||
|
||||
It("generates a random ID and OBIT", func() {
|
||||
scfg1, err := NewServerConfig(kex, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg2, err := NewServerConfig(kex, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scfg1.ID).ToNot(Equal(scfg2.ID))
|
||||
Expect(scfg1.obit).ToNot(Equal(scfg2.obit))
|
||||
Expect(scfg1.cookieGenerator).ToNot(Equal(scfg2.cookieGenerator))
|
||||
})
|
||||
|
||||
It("gets the proper binary representation", func() {
|
||||
scfg, err := NewServerConfig(kex, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
expected := bytes.NewBuffer([]byte{0x53, 0x43, 0x46, 0x47, 0x6, 0x0, 0x0, 0x0, 0x41, 0x45, 0x41, 0x44, 0x4, 0x0, 0x0, 0x0, 0x53, 0x43, 0x49, 0x44, 0x14, 0x0, 0x0, 0x0, 0x50, 0x55, 0x42, 0x53, 0x37, 0x0, 0x0, 0x0, 0x4b, 0x45, 0x58, 0x53, 0x3b, 0x0, 0x0, 0x0, 0x4f, 0x42, 0x49, 0x54, 0x43, 0x0, 0x0, 0x0, 0x45, 0x58, 0x50, 0x59, 0x4b, 0x0, 0x0, 0x0, 0x41, 0x45, 0x53, 0x47})
|
||||
expected.Write(scfg.ID)
|
||||
expected.Write([]byte{0x20, 0x0, 0x0})
|
||||
expected.Write(kex.PublicKey())
|
||||
expected.Write([]byte{0x43, 0x32, 0x35, 0x35})
|
||||
expected.Write(scfg.obit)
|
||||
expected.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
|
||||
Expect(scfg.Get()).To(Equal(expected.Bytes()))
|
||||
})
|
||||
})
|
|
@ -1,93 +0,0 @@
|
|||
package handshake
|
||||
|
||||
// A Tag in the QUIC crypto
|
||||
type Tag uint32
|
||||
|
||||
const (
|
||||
// TagCHLO is a client hello
|
||||
TagCHLO Tag = 'C' + 'H'<<8 + 'L'<<16 + 'O'<<24
|
||||
// TagREJ is a server hello rejection
|
||||
TagREJ Tag = 'R' + 'E'<<8 + 'J'<<16
|
||||
// TagSCFG is a server config
|
||||
TagSCFG Tag = 'S' + 'C'<<8 + 'F'<<16 + 'G'<<24
|
||||
|
||||
// TagPAD is padding
|
||||
TagPAD Tag = 'P' + 'A'<<8 + 'D'<<16
|
||||
// TagSNI is the server name indication
|
||||
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
|
||||
// TagVER is the QUIC version
|
||||
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
|
||||
// TagCCS are the hashes of the common certificate sets
|
||||
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
|
||||
// TagCCRT are the hashes of the cached certificates
|
||||
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
|
||||
// TagMSPC is max streams per connection
|
||||
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
|
||||
// TagMIDS is max incoming dyanamic streams
|
||||
TagMIDS Tag = 'M' + 'I'<<8 + 'D'<<16 + 'S'<<24
|
||||
// TagUAID is the user agent ID
|
||||
TagUAID Tag = 'U' + 'A'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagSVID is the server ID (unofficial tag by us :)
|
||||
TagSVID Tag = 'S' + 'V'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagTCID is truncation of the connection ID
|
||||
TagTCID Tag = 'T' + 'C'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagPDMD is the proof demand
|
||||
TagPDMD Tag = 'P' + 'D'<<8 + 'M'<<16 + 'D'<<24
|
||||
// TagSRBF is the socket receive buffer
|
||||
TagSRBF Tag = 'S' + 'R'<<8 + 'B'<<16 + 'F'<<24
|
||||
// TagICSL is the idle connection state lifetime
|
||||
TagICSL Tag = 'I' + 'C'<<8 + 'S'<<16 + 'L'<<24
|
||||
// TagNONP is the client proof nonce
|
||||
TagNONP Tag = 'N' + 'O'<<8 + 'N'<<16 + 'P'<<24
|
||||
// TagSCLS is the silently close timeout
|
||||
TagSCLS Tag = 'S' + 'C'<<8 + 'L'<<16 + 'S'<<24
|
||||
// TagCSCT is the signed cert timestamp (RFC6962) of leaf cert
|
||||
TagCSCT Tag = 'C' + 'S'<<8 + 'C'<<16 + 'T'<<24
|
||||
// TagCOPT are the connection options
|
||||
TagCOPT Tag = 'C' + 'O'<<8 + 'P'<<16 + 'T'<<24
|
||||
// TagCFCW is the initial session/connection flow control receive window
|
||||
TagCFCW Tag = 'C' + 'F'<<8 + 'C'<<16 + 'W'<<24
|
||||
// TagSFCW is the initial stream flow control receive window.
|
||||
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
|
||||
|
||||
// TagNSTP is the no STOP_WAITING experiment
|
||||
// currently unsupported by quic-go
|
||||
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
|
||||
|
||||
// TagSTK is the source-address token
|
||||
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
|
||||
// TagSNO is the server nonce
|
||||
TagSNO Tag = 'S' + 'N'<<8 + 'O'<<16
|
||||
// TagPROF is the server proof
|
||||
TagPROF Tag = 'P' + 'R'<<8 + 'O'<<16 + 'F'<<24
|
||||
|
||||
// TagNONC is the client nonce
|
||||
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
|
||||
// TagXLCT is the expected leaf certificate
|
||||
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
|
||||
|
||||
// TagSCID is the server config ID
|
||||
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
|
||||
// TagKEXS is the list of key exchange algos
|
||||
TagKEXS Tag = 'K' + 'E'<<8 + 'X'<<16 + 'S'<<24
|
||||
// TagAEAD is the list of AEAD algos
|
||||
TagAEAD Tag = 'A' + 'E'<<8 + 'A'<<16 + 'D'<<24
|
||||
// TagPUBS is the public value for the KEX
|
||||
TagPUBS Tag = 'P' + 'U'<<8 + 'B'<<16 + 'S'<<24
|
||||
// TagOBIT is the client orbit
|
||||
TagOBIT Tag = 'O' + 'B'<<8 + 'I'<<16 + 'T'<<24
|
||||
// TagEXPY is the server config expiry
|
||||
TagEXPY Tag = 'E' + 'X'<<8 + 'P'<<16 + 'Y'<<24
|
||||
// TagCERT is the CERT data
|
||||
TagCERT Tag = 0xff545243
|
||||
|
||||
// TagSHLO is the server hello
|
||||
TagSHLO Tag = 'S' + 'H'<<8 + 'L'<<16 + 'O'<<24
|
||||
|
||||
// TagPRST is the public reset tag
|
||||
TagPRST Tag = 'P' + 'R'<<8 + 'S'<<16 + 'T'<<24
|
||||
// TagRSEQ is the public reset rejected packet number
|
||||
TagRSEQ Tag = 'R' + 'S'<<8 + 'E'<<16 + 'Q'<<24
|
||||
// TagRNON is the public reset nonce
|
||||
TagRNON Tag = 'R' + 'N'<<8 + 'O'<<16 + 'N'<<24
|
||||
)
|
|
@ -17,14 +17,15 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
|
|||
handler *extensionHandlerClient
|
||||
paramsChan <-chan TransportParameters
|
||||
)
|
||||
version := protocol.VersionNumber(0x42)
|
||||
|
||||
BeforeEach(func() {
|
||||
var h tlsExtensionHandler
|
||||
h, paramsChan = newExtensionHandlerClient(
|
||||
&TransportParameters{},
|
||||
protocol.VersionWhatever,
|
||||
version,
|
||||
nil,
|
||||
protocol.VersionWhatever,
|
||||
version,
|
||||
utils.DefaultLogger,
|
||||
)
|
||||
handler = h.(*extensionHandlerClient)
|
||||
|
@ -57,6 +58,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() {
|
|||
Type: quicTLSExtensionType,
|
||||
Data: (&encryptedExtensionsTransportParameters{
|
||||
Parameters: params,
|
||||
NegotiatedVersion: version,
|
||||
SupportedVersions: []protocol.VersionNumber{handler.version},
|
||||
}).Marshal(),
|
||||
}
|
||||
|
|
|
@ -12,260 +12,164 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("Transport Parameters", func() {
|
||||
Context("for gQUIC", func() {
|
||||
Context("parsing", func() {
|
||||
It("sets all values", func() {
|
||||
values := map[Tag][]byte{
|
||||
TagSFCW: {0xad, 0xfb, 0xca, 0xde},
|
||||
TagCFCW: {0xef, 0xbe, 0xad, 0xde},
|
||||
TagICSL: {0x0d, 0xf0, 0xad, 0xba},
|
||||
TagMIDS: {0xff, 0x10, 0x00, 0xc0},
|
||||
}
|
||||
params, err := readHelloMap(values)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0xdecafbad)))
|
||||
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0xdeadbeef)))
|
||||
Expect(params.IdleTimeout).To(Equal(time.Duration(0xbaadf00d) * time.Second))
|
||||
Expect(params.MaxStreams).To(Equal(uint32(0xc00010ff)))
|
||||
Expect(params.OmitConnectionID).To(BeFalse())
|
||||
})
|
||||
It("has a string representation", func() {
|
||||
p := &TransportParameters{
|
||||
StreamFlowControlWindow: 0x1234,
|
||||
ConnectionFlowControlWindow: 0x4321,
|
||||
MaxBidiStreams: 1337,
|
||||
MaxUniStreams: 7331,
|
||||
IdleTimeout: 42 * time.Second,
|
||||
}
|
||||
Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}"))
|
||||
})
|
||||
|
||||
It("reads if the connection ID should be omitted", func() {
|
||||
values := map[Tag][]byte{TagTCID: {0, 0, 0, 0}}
|
||||
params, err := readHelloMap(values)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(params.OmitConnectionID).To(BeTrue())
|
||||
})
|
||||
Context("parsing", func() {
|
||||
var (
|
||||
params *TransportParameters
|
||||
parameters map[transportParameterID][]byte
|
||||
statelessResetToken []byte
|
||||
)
|
||||
|
||||
It("doesn't allow idle timeouts below the minimum remote idle timeout", func() {
|
||||
t := 2 * time.Second
|
||||
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
||||
values := map[Tag][]byte{
|
||||
TagICSL: {uint8(t.Seconds()), 0, 0, 0},
|
||||
}
|
||||
params, err := readHelloMap(values)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
|
||||
})
|
||||
marshal := func(p map[transportParameterID][]byte) []byte {
|
||||
b := &bytes.Buffer{}
|
||||
for id, val := range p {
|
||||
utils.BigEndian.WriteUint16(b, uint16(id))
|
||||
utils.BigEndian.WriteUint16(b, uint16(len(val)))
|
||||
b.Write(val)
|
||||
}
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
It("errors when given an invalid SFCW value", func() {
|
||||
values := map[Tag][]byte{TagSFCW: {2, 0, 0}} // 1 byte too short
|
||||
_, err := readHelloMap(values)
|
||||
Expect(err).To(MatchError(errMalformedTag))
|
||||
})
|
||||
|
||||
It("errors when given an invalid CFCW value", func() {
|
||||
values := map[Tag][]byte{TagCFCW: {2, 0, 0}} // 1 byte too short
|
||||
_, err := readHelloMap(values)
|
||||
Expect(err).To(MatchError(errMalformedTag))
|
||||
})
|
||||
|
||||
It("errors when given an invalid TCID value", func() {
|
||||
values := map[Tag][]byte{TagTCID: {2, 0, 0}} // 1 byte too short
|
||||
_, err := readHelloMap(values)
|
||||
Expect(err).To(MatchError(errMalformedTag))
|
||||
})
|
||||
|
||||
It("errors when given an invalid ICSL value", func() {
|
||||
values := map[Tag][]byte{TagICSL: {2, 0, 0}} // 1 byte too short
|
||||
_, err := readHelloMap(values)
|
||||
Expect(err).To(MatchError(errMalformedTag))
|
||||
})
|
||||
|
||||
It("errors when given an invalid MIDS value", func() {
|
||||
values := map[Tag][]byte{TagMIDS: {2, 0, 0}} // 1 byte too short
|
||||
_, err := readHelloMap(values)
|
||||
Expect(err).To(MatchError(errMalformedTag))
|
||||
})
|
||||
BeforeEach(func() {
|
||||
params = &TransportParameters{}
|
||||
statelessResetToken = bytes.Repeat([]byte{42}, 16)
|
||||
parameters = map[transportParameterID][]byte{
|
||||
initialMaxStreamDataParameterID: {0x11, 0x22, 0x33, 0x44},
|
||||
initialMaxDataParameterID: {0x22, 0x33, 0x44, 0x55},
|
||||
initialMaxBidiStreamsParameterID: {0x33, 0x44},
|
||||
initialMaxUniStreamsParameterID: {0x44, 0x55},
|
||||
idleTimeoutParameterID: {0x13, 0x37},
|
||||
maxPacketSizeParameterID: {0x73, 0x31},
|
||||
disableMigrationParameterID: {},
|
||||
statelessResetTokenParameterID: statelessResetToken,
|
||||
}
|
||||
})
|
||||
It("reads parameters", func() {
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
|
||||
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
|
||||
Expect(params.MaxBidiStreams).To(Equal(uint16(0x3344)))
|
||||
Expect(params.MaxUniStreams).To(Equal(uint16(0x4455)))
|
||||
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
||||
Expect(params.MaxPacketSize).To(Equal(protocol.ByteCount(0x7331)))
|
||||
Expect(params.DisableMigration).To(BeTrue())
|
||||
Expect(params.StatelessResetToken).To(Equal(statelessResetToken))
|
||||
})
|
||||
|
||||
Context("writing", func() {
|
||||
It("returns all necessary parameters ", func() {
|
||||
params := &TransportParameters{
|
||||
StreamFlowControlWindow: 0xdeadbeef,
|
||||
ConnectionFlowControlWindow: 0xdecafbad,
|
||||
IdleTimeout: 0xbaaaaaad * time.Second,
|
||||
MaxStreams: 0x1337,
|
||||
}
|
||||
entryMap := params.getHelloMap()
|
||||
Expect(entryMap).To(HaveLen(4))
|
||||
Expect(entryMap).ToNot(HaveKey(TagTCID))
|
||||
Expect(entryMap).To(HaveKeyWithValue(TagSFCW, []byte{0xef, 0xbe, 0xad, 0xde}))
|
||||
Expect(entryMap).To(HaveKeyWithValue(TagCFCW, []byte{0xad, 0xfb, 0xca, 0xde}))
|
||||
Expect(entryMap).To(HaveKeyWithValue(TagICSL, []byte{0xad, 0xaa, 0xaa, 0xba}))
|
||||
Expect(entryMap).To(HaveKeyWithValue(TagMIDS, []byte{0x37, 0x13, 0, 0}))
|
||||
})
|
||||
It("errors if a parameter is sent twice", func() {
|
||||
data := marshal(parameters)
|
||||
parameters = map[transportParameterID][]byte{
|
||||
maxPacketSizeParameterID: {0x73, 0x31},
|
||||
}
|
||||
data = append(data, marshal(parameters)...)
|
||||
err := params.unmarshal(data)
|
||||
Expect(err).To(MatchError(fmt.Sprintf("received duplicate transport parameter %#x", maxPacketSizeParameterID)))
|
||||
})
|
||||
|
||||
It("requests omission of the connection ID", func() {
|
||||
params := &TransportParameters{OmitConnectionID: true}
|
||||
entryMap := params.getHelloMap()
|
||||
Expect(entryMap).To(HaveKeyWithValue(TagTCID, []byte{0, 0, 0, 0}))
|
||||
})
|
||||
It("doesn't allow values below the minimum remote idle timeout", func() {
|
||||
t := 2 * time.Second
|
||||
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
||||
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
|
||||
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_data has the wrong length", func() {
|
||||
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
|
||||
parameters[initialMaxBidiStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_stream_id_bidi: 3 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
|
||||
parameters[initialMaxUniStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_stream_id_uni: 3 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
|
||||
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if max_packet_size has the wrong length", func() {
|
||||
parameters[maxPacketSizeParameterID] = []byte{0x11} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for max_packet_size: 1 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects max_packet_sizes smaller than 1200 bytes", func() {
|
||||
parameters[maxPacketSizeParameterID] = []byte{0x4, 0xaf} // 0x4af = 1199
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if disable_connection_migration has the wrong length", func() {
|
||||
parameters[disableMigrationParameterID] = []byte{0x11} // should empty
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for disable_migration: 1 (expected empty)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the stateless_reset_token has the wrong length", func() {
|
||||
parameters[statelessResetTokenParameterID] = statelessResetToken[1:]
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)"))
|
||||
})
|
||||
|
||||
It("ignores unknown parameters", func() {
|
||||
parameters[1337] = []byte{42}
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("for TLS", func() {
|
||||
It("has a string representation", func() {
|
||||
p := &TransportParameters{
|
||||
StreamFlowControlWindow: 0x1234,
|
||||
ConnectionFlowControlWindow: 0x4321,
|
||||
MaxBidiStreams: 1337,
|
||||
MaxUniStreams: 7331,
|
||||
IdleTimeout: 42 * time.Second,
|
||||
Context("marshalling", func() {
|
||||
It("marshals", func() {
|
||||
params := &TransportParameters{
|
||||
StreamFlowControlWindow: 0xdeadbeef,
|
||||
ConnectionFlowControlWindow: 0xdecafbad,
|
||||
IdleTimeout: 0xcafe * time.Second,
|
||||
MaxBidiStreams: 0x1234,
|
||||
MaxUniStreams: 0x4321,
|
||||
DisableMigration: true,
|
||||
StatelessResetToken: bytes.Repeat([]byte{100}, 16),
|
||||
}
|
||||
Expect(p.String()).To(Equal("&handshake.TransportParameters{StreamFlowControlWindow: 0x1234, ConnectionFlowControlWindow: 0x4321, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s}"))
|
||||
})
|
||||
b := &bytes.Buffer{}
|
||||
params.marshal(b)
|
||||
|
||||
Context("parsing", func() {
|
||||
var (
|
||||
params *TransportParameters
|
||||
parameters map[transportParameterID][]byte
|
||||
statelessResetToken []byte
|
||||
)
|
||||
|
||||
marshal := func(p map[transportParameterID][]byte) []byte {
|
||||
b := &bytes.Buffer{}
|
||||
for id, val := range p {
|
||||
utils.BigEndian.WriteUint16(b, uint16(id))
|
||||
utils.BigEndian.WriteUint16(b, uint16(len(val)))
|
||||
b.Write(val)
|
||||
}
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
params = &TransportParameters{}
|
||||
statelessResetToken = bytes.Repeat([]byte{42}, 16)
|
||||
parameters = map[transportParameterID][]byte{
|
||||
initialMaxStreamDataParameterID: {0x11, 0x22, 0x33, 0x44},
|
||||
initialMaxDataParameterID: {0x22, 0x33, 0x44, 0x55},
|
||||
initialMaxBidiStreamsParameterID: {0x33, 0x44},
|
||||
initialMaxUniStreamsParameterID: {0x44, 0x55},
|
||||
idleTimeoutParameterID: {0x13, 0x37},
|
||||
maxPacketSizeParameterID: {0x73, 0x31},
|
||||
disableMigrationParameterID: {},
|
||||
statelessResetTokenParameterID: statelessResetToken,
|
||||
}
|
||||
})
|
||||
|
||||
It("reads parameters", func() {
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(params.StreamFlowControlWindow).To(Equal(protocol.ByteCount(0x11223344)))
|
||||
Expect(params.ConnectionFlowControlWindow).To(Equal(protocol.ByteCount(0x22334455)))
|
||||
Expect(params.MaxBidiStreams).To(Equal(uint16(0x3344)))
|
||||
Expect(params.MaxUniStreams).To(Equal(uint16(0x4455)))
|
||||
Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second))
|
||||
Expect(params.OmitConnectionID).To(BeFalse())
|
||||
Expect(params.MaxPacketSize).To(Equal(protocol.ByteCount(0x7331)))
|
||||
Expect(params.DisableMigration).To(BeTrue())
|
||||
Expect(params.StatelessResetToken).To(Equal(statelessResetToken))
|
||||
})
|
||||
|
||||
It("errors if a parameter is sent twice", func() {
|
||||
data := marshal(parameters)
|
||||
parameters = map[transportParameterID][]byte{
|
||||
maxPacketSizeParameterID: {0x73, 0x31},
|
||||
}
|
||||
data = append(data, marshal(parameters)...)
|
||||
err := params.unmarshal(data)
|
||||
Expect(err).To(MatchError(fmt.Sprintf("received duplicate transport parameter %#x", maxPacketSizeParameterID)))
|
||||
})
|
||||
|
||||
It("doesn't allow values below the minimum remote idle timeout", func() {
|
||||
t := 2 * time.Second
|
||||
Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout))
|
||||
parameters[idleTimeoutParameterID] = []byte{0, uint8(t.Seconds())}
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(params.IdleTimeout).To(Equal(protocol.MinRemoteIdleTimeout))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_stream_data has the wrong length", func() {
|
||||
parameters[initialMaxStreamDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_stream_data: 3 (expected 4)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_data has the wrong length", func() {
|
||||
parameters[initialMaxDataParameterID] = []byte{0x11, 0x22, 0x33} // should be 4 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_data: 3 (expected 4)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
|
||||
parameters[initialMaxBidiStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_stream_id_bidi: 3 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_max_stream_id_bidi has the wrong length", func() {
|
||||
parameters[initialMaxUniStreamsParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for initial_max_stream_id_uni: 3 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the initial_idle_timeout has the wrong length", func() {
|
||||
parameters[idleTimeoutParameterID] = []byte{0x11, 0x22, 0x33} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for idle_timeout: 3 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if max_packet_size has the wrong length", func() {
|
||||
parameters[maxPacketSizeParameterID] = []byte{0x11} // should be 2 bytes
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for max_packet_size: 1 (expected 2)"))
|
||||
})
|
||||
|
||||
It("rejects max_packet_sizes smaller than 1200 bytes", func() {
|
||||
parameters[maxPacketSizeParameterID] = []byte{0x4, 0xaf} // 0x4af = 1199
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if disable_connection_migration has the wrong length", func() {
|
||||
parameters[disableMigrationParameterID] = []byte{0x11} // should empty
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for disable_migration: 1 (expected empty)"))
|
||||
})
|
||||
|
||||
It("rejects the parameters if the stateless_reset_token has the wrong length", func() {
|
||||
parameters[statelessResetTokenParameterID] = statelessResetToken[1:]
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)"))
|
||||
})
|
||||
|
||||
It("ignores unknown parameters", func() {
|
||||
parameters[1337] = []byte{42}
|
||||
err := params.unmarshal(marshal(parameters))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("marshalling", func() {
|
||||
It("marshals", func() {
|
||||
params := &TransportParameters{
|
||||
StreamFlowControlWindow: 0xdeadbeef,
|
||||
ConnectionFlowControlWindow: 0xdecafbad,
|
||||
IdleTimeout: 0xcafe * time.Second,
|
||||
MaxBidiStreams: 0x1234,
|
||||
MaxUniStreams: 0x4321,
|
||||
DisableMigration: true,
|
||||
StatelessResetToken: bytes.Repeat([]byte{100}, 16),
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
params.marshal(b)
|
||||
|
||||
p := &TransportParameters{}
|
||||
Expect(p.unmarshal(b.Bytes())).To(Succeed())
|
||||
Expect(p.StreamFlowControlWindow).To(Equal(params.StreamFlowControlWindow))
|
||||
Expect(p.ConnectionFlowControlWindow).To(Equal(params.ConnectionFlowControlWindow))
|
||||
Expect(p.MaxUniStreams).To(Equal(params.MaxUniStreams))
|
||||
Expect(p.MaxBidiStreams).To(Equal(params.MaxBidiStreams))
|
||||
Expect(p.IdleTimeout).To(Equal(params.IdleTimeout))
|
||||
Expect(p.DisableMigration).To(Equal(params.DisableMigration))
|
||||
Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken))
|
||||
})
|
||||
p := &TransportParameters{}
|
||||
Expect(p.unmarshal(b.Bytes())).To(Succeed())
|
||||
Expect(p.StreamFlowControlWindow).To(Equal(params.StreamFlowControlWindow))
|
||||
Expect(p.ConnectionFlowControlWindow).To(Equal(params.ConnectionFlowControlWindow))
|
||||
Expect(p.MaxUniStreams).To(Equal(params.MaxUniStreams))
|
||||
Expect(p.MaxBidiStreams).To(Equal(params.MaxBidiStreams))
|
||||
Expect(p.IdleTimeout).To(Equal(params.IdleTimeout))
|
||||
Expect(p.DisableMigration).To(Equal(params.DisableMigration))
|
||||
Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -9,12 +9,8 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/utils"
|
||||
"github.com/lucas-clemente/quic-go/qerr"
|
||||
)
|
||||
|
||||
// errMalformedTag is returned when the tag value cannot be read
|
||||
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
|
||||
|
||||
// TransportParameters are parameters sent to the peer during the handshake
|
||||
type TransportParameters struct {
|
||||
StreamFlowControlWindow protocol.ByteCount
|
||||
|
@ -22,78 +18,12 @@ type TransportParameters struct {
|
|||
|
||||
MaxPacketSize protocol.ByteCount
|
||||
|
||||
MaxUniStreams uint16 // only used for IETF QUIC
|
||||
MaxBidiStreams uint16 // only used for IETF QUIC
|
||||
MaxStreams uint32 // only used for gQUIC
|
||||
MaxUniStreams uint16
|
||||
MaxBidiStreams uint16
|
||||
|
||||
OmitConnectionID bool // only used for gQUIC
|
||||
IdleTimeout time.Duration
|
||||
DisableMigration bool // only used for IETF QUIC
|
||||
StatelessResetToken []byte // only used for IETF QUIC
|
||||
}
|
||||
|
||||
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
|
||||
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
|
||||
params := &TransportParameters{}
|
||||
if value, ok := tags[TagTCID]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.OmitConnectionID = (v == 0)
|
||||
}
|
||||
if value, ok := tags[TagMIDS]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.MaxStreams = v
|
||||
}
|
||||
if value, ok := tags[TagICSL]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
|
||||
}
|
||||
if value, ok := tags[TagSFCW]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.StreamFlowControlWindow = protocol.ByteCount(v)
|
||||
}
|
||||
if value, ok := tags[TagCFCW]; ok {
|
||||
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
|
||||
if err != nil {
|
||||
return nil, errMalformedTag
|
||||
}
|
||||
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
|
||||
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
|
||||
sfcw := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
|
||||
cfcw := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
|
||||
mids := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
|
||||
icsl := bytes.NewBuffer([]byte{})
|
||||
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
|
||||
|
||||
tags := map[Tag][]byte{
|
||||
TagICSL: icsl.Bytes(),
|
||||
TagMIDS: mids.Bytes(),
|
||||
TagCFCW: cfcw.Bytes(),
|
||||
TagSFCW: sfcw.Bytes(),
|
||||
}
|
||||
if p.OmitConnectionID {
|
||||
tags[TagTCID] = []byte{0, 0, 0, 0}
|
||||
}
|
||||
return tags
|
||||
DisableMigration bool
|
||||
StatelessResetToken []byte
|
||||
}
|
||||
|
||||
func (p *TransportParameters) unmarshal(data []byte) error {
|
||||
|
@ -209,7 +139,6 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) {
|
|||
}
|
||||
|
||||
// String returns a string representation, intended for logging.
|
||||
// It should only used for IETF QUIC.
|
||||
func (p *TransportParameters) String() string {
|
||||
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
|
||||
}
|
||||
|
|
|
@ -98,18 +98,6 @@ func (mr *MockSentPacketHandlerMockRecorder) GetPacketNumberLen(arg0 interface{}
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPacketNumberLen", reflect.TypeOf((*MockSentPacketHandler)(nil).GetPacketNumberLen), arg0)
|
||||
}
|
||||
|
||||
// GetStopWaitingFrame mocks base method
|
||||
func (m *MockSentPacketHandler) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame {
|
||||
ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0)
|
||||
ret0, _ := ret[0].(*wire.StopWaitingFrame)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame
|
||||
func (mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0)
|
||||
}
|
||||
|
||||
// OnAlarm mocks base method
|
||||
func (m *MockSentPacketHandler) OnAlarm() error {
|
||||
ret := m.ctrl.Call(m, "OnAlarm")
|
||||
|
|
149
internal/mocks/crypto_setup.go
Normal file
149
internal/mocks/crypto_setup.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: CryptoSetup)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
handshake "github.com/lucas-clemente/quic-go/internal/handshake"
|
||||
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// MockCryptoSetup is a mock of CryptoSetup interface
|
||||
type MockCryptoSetup struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCryptoSetupMockRecorder
|
||||
}
|
||||
|
||||
// MockCryptoSetupMockRecorder is the mock recorder for MockCryptoSetup
|
||||
type MockCryptoSetupMockRecorder struct {
|
||||
mock *MockCryptoSetup
|
||||
}
|
||||
|
||||
// NewMockCryptoSetup creates a new mock instance
|
||||
func NewMockCryptoSetup(ctrl *gomock.Controller) *MockCryptoSetup {
|
||||
mock := &MockCryptoSetup{ctrl: ctrl}
|
||||
mock.recorder = &MockCryptoSetupMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockCryptoSetup) EXPECT() *MockCryptoSetupMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockCryptoSetup) Close() error {
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close))
|
||||
}
|
||||
|
||||
// ConnectionState mocks base method
|
||||
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
|
||||
ret := m.ctrl.Call(m, "ConnectionState")
|
||||
ret0, _ := ret[0].(handshake.ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ConnectionState indicates an expected call of ConnectionState
|
||||
func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState))
|
||||
}
|
||||
|
||||
// GetSealer mocks base method
|
||||
func (m *MockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) {
|
||||
ret := m.ctrl.Call(m, "GetSealer")
|
||||
ret0, _ := ret[0].(protocol.EncryptionLevel)
|
||||
ret1, _ := ret[1].(handshake.Sealer)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSealer indicates an expected call of GetSealer
|
||||
func (mr *MockCryptoSetupMockRecorder) GetSealer() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealer))
|
||||
}
|
||||
|
||||
// GetSealerWithEncryptionLevel mocks base method
|
||||
func (m *MockCryptoSetup) GetSealerWithEncryptionLevel(arg0 protocol.EncryptionLevel) (handshake.Sealer, error) {
|
||||
ret := m.ctrl.Call(m, "GetSealerWithEncryptionLevel", arg0)
|
||||
ret0, _ := ret[0].(handshake.Sealer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSealerWithEncryptionLevel indicates an expected call of GetSealerWithEncryptionLevel
|
||||
func (mr *MockCryptoSetupMockRecorder) GetSealerWithEncryptionLevel(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerWithEncryptionLevel", reflect.TypeOf((*MockCryptoSetup)(nil).GetSealerWithEncryptionLevel), arg0)
|
||||
}
|
||||
|
||||
// HandleMessage mocks base method
|
||||
func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool {
|
||||
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HandleMessage indicates an expected call of HandleMessage
|
||||
func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1)
|
||||
}
|
||||
|
||||
// Open1RTT mocks base method
|
||||
func (m *MockCryptoSetup) Open1RTT(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
ret := m.ctrl.Call(m, "Open1RTT", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Open1RTT indicates an expected call of Open1RTT
|
||||
func (mr *MockCryptoSetupMockRecorder) Open1RTT(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open1RTT", reflect.TypeOf((*MockCryptoSetup)(nil).Open1RTT), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// OpenHandshake mocks base method
|
||||
func (m *MockCryptoSetup) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
ret := m.ctrl.Call(m, "OpenHandshake", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenHandshake indicates an expected call of OpenHandshake
|
||||
func (mr *MockCryptoSetupMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).OpenHandshake), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// OpenInitial mocks base method
|
||||
func (m *MockCryptoSetup) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) {
|
||||
ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenInitial indicates an expected call of OpenInitial
|
||||
func (mr *MockCryptoSetupMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockCryptoSetup)(nil).OpenInitial), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// RunHandshake mocks base method
|
||||
func (m *MockCryptoSetup) RunHandshake() error {
|
||||
ret := m.ctrl.Call(m, "RunHandshake")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RunHandshake indicates an expected call of RunHandshake
|
||||
func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake))
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package mocks
|
||||
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup"
|
||||
//go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler"
|
||||
//go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/received_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler ReceivedPacketHandler"
|
||||
|
|
|
@ -7,30 +7,16 @@ type EncryptionLevel int
|
|||
const (
|
||||
// EncryptionUnspecified is a not specified encryption level
|
||||
EncryptionUnspecified EncryptionLevel = iota
|
||||
// EncryptionUnencrypted is not encrypted, for gQUIC
|
||||
EncryptionUnencrypted
|
||||
// EncryptionInitial is the Initial encryption level
|
||||
EncryptionInitial
|
||||
// EncryptionSecure is encrypted, but not forward secure
|
||||
EncryptionSecure
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake
|
||||
// EncryptionForwardSecure is forward secure
|
||||
EncryptionForwardSecure
|
||||
// Encryption1RTT is the 1-RTT encryption level
|
||||
Encryption1RTT
|
||||
)
|
||||
|
||||
func (e EncryptionLevel) String() string {
|
||||
switch e {
|
||||
// gQUIC
|
||||
case EncryptionUnencrypted:
|
||||
return "unencrypted"
|
||||
case EncryptionSecure:
|
||||
return "encrypted (not forward-secure)"
|
||||
case EncryptionForwardSecure:
|
||||
return "forward-secure"
|
||||
// IETF QUIC
|
||||
case EncryptionInitial:
|
||||
return "Initial"
|
||||
case EncryptionHandshake:
|
||||
|
|
|
@ -8,9 +8,6 @@ import (
|
|||
var _ = Describe("Encryption Level", func() {
|
||||
It("has the correct string representation", func() {
|
||||
Expect(EncryptionUnspecified.String()).To(Equal("unknown"))
|
||||
Expect(EncryptionUnencrypted.String()).To(Equal("unencrypted"))
|
||||
Expect(EncryptionSecure.String()).To(Equal("encrypted (not forward-secure)"))
|
||||
Expect(EncryptionForwardSecure.String()).To(Equal("forward-secure"))
|
||||
Expect(EncryptionInitial.String()).To(Equal("Initial"))
|
||||
Expect(EncryptionHandshake.String()).To(Equal("Handshake"))
|
||||
Expect(Encryption1RTT.String()).To(Equal("1-RTT"))
|
||||
|
|
|
@ -8,17 +8,13 @@ func InferPacketNumber(
|
|||
version VersionNumber,
|
||||
) PacketNumber {
|
||||
var epochDelta PacketNumber
|
||||
if version.UsesVarintPacketNumbers() {
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 7
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 14
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 30
|
||||
}
|
||||
} else {
|
||||
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 7
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 14
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 30
|
||||
}
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
prevEpochBegin := epoch - epochDelta
|
||||
|
@ -48,8 +44,7 @@ func delta(a, b PacketNumber) PacketNumber {
|
|||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - leastUnacked)
|
||||
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
|
||||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
|
||||
if diff < (1 << (14 - 1)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
return PacketNumberLen4
|
||||
|
|
|
@ -12,17 +12,15 @@ import (
|
|||
var _ = Describe("packet number calculation", func() {
|
||||
Context("infering a packet number", func() {
|
||||
getEpoch := func(len PacketNumberLen, v VersionNumber) uint64 {
|
||||
if v.UsesVarintPacketNumbers() {
|
||||
switch len {
|
||||
case PacketNumberLen1:
|
||||
return uint64(1) << 7
|
||||
case PacketNumberLen2:
|
||||
return uint64(1) << 14
|
||||
case PacketNumberLen4:
|
||||
return uint64(1) << 30
|
||||
default:
|
||||
Fail("invalid packet number len")
|
||||
}
|
||||
switch len {
|
||||
case PacketNumberLen1:
|
||||
return uint64(1) << 7
|
||||
case PacketNumberLen2:
|
||||
return uint64(1) << 14
|
||||
case PacketNumberLen4:
|
||||
return uint64(1) << 30
|
||||
default:
|
||||
Fail("invalid packet number len")
|
||||
}
|
||||
return uint64(1) << (len * 8)
|
||||
}
|
||||
|
@ -33,134 +31,128 @@ var _ = Describe("packet number calculation", func() {
|
|||
Expect(InferPacketNumber(length, PacketNumber(last), PacketNumber(wirePacketNumber), v)).To(Equal(PacketNumber(expected)))
|
||||
}
|
||||
|
||||
for _, v := range []VersionNumber{Version39, VersionTLS} {
|
||||
version := v
|
||||
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
|
||||
length := l
|
||||
|
||||
Context(fmt.Sprintf("using varint packet numbers: %t", version.UsesVarintPacketNumbers()), func() {
|
||||
for _, l := range []PacketNumberLen{PacketNumberLen1, PacketNumberLen2, PacketNumberLen4} {
|
||||
length := l
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := getEpoch(length, VersionWhatever)
|
||||
epochMask := epoch - 1
|
||||
|
||||
Context(fmt.Sprintf("with %d bytes", length), func() {
|
||||
epoch := getEpoch(length, version)
|
||||
epochMask := epoch - 1
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0, VersionWhatever)
|
||||
check(length, epoch+1, epochMask, VersionWhatever)
|
||||
check(length, epoch, epochMask, VersionWhatever)
|
||||
|
||||
It("works near epoch start", func() {
|
||||
// A few quick manual sanity check
|
||||
check(length, 1, 0, version)
|
||||
check(length, epoch+1, epochMask, version)
|
||||
check(length, epoch, epochMask, version)
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last, VersionWhatever)
|
||||
}
|
||||
|
||||
// Cases where the last number was close to the start of the range.
|
||||
for last := uint64(0); last < 10; last++ {
|
||||
// Small numbers should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, j, last, version)
|
||||
}
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last, VersionWhatever)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Large numbers should not wrap either (because we're near 0 already).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
|
||||
It("works near epoch end", func() {
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := epoch - i
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last, VersionWhatever)
|
||||
}
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch+j, last, version)
|
||||
}
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last, VersionWhatever)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, epoch-1-j, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last, VersionWhatever)
|
||||
}
|
||||
|
||||
// Next check where we're in a non-zero epoch to verify we handle
|
||||
// reverse wrapping, too.
|
||||
It("works near previous epoch", func() {
|
||||
prevEpoch := 1 * epoch
|
||||
curEpoch := 2 * epoch
|
||||
// Cases where the last number was close to the start of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := curEpoch + i
|
||||
// Small number should not wrap (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, curEpoch+j, last, version)
|
||||
}
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last, VersionWhatever)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// But large numbers should reverse wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, prevEpoch+num, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
|
||||
It("works near next epoch", func() {
|
||||
curEpoch := 2 * epoch
|
||||
nextEpoch := 3 * epoch
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
last := nextEpoch - 1 - i
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last, VersionWhatever)
|
||||
}
|
||||
|
||||
// Small numbers should wrap.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, nextEpoch+j, last, version)
|
||||
}
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last, VersionWhatever)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// but large numbers should not (even if they're out of order).
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, curEpoch+num, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
It("works near next max", func() {
|
||||
maxNumber := uint64(math.MaxUint64)
|
||||
maxEpoch := maxNumber & ^epochMask
|
||||
|
||||
It("works near next max", func() {
|
||||
maxNumber := uint64(math.MaxUint64)
|
||||
maxEpoch := maxNumber & ^epochMask
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
// Subtract 1, because the expected next packet number is 1 more than the
|
||||
// last packet number.
|
||||
last := maxNumber - i - 1
|
||||
|
||||
// Cases where the last number was close to the end of the range
|
||||
for i := uint64(0); i < 10; i++ {
|
||||
// Subtract 1, because the expected next packet number is 1 more than the
|
||||
// last packet number.
|
||||
last := maxNumber - i - 1
|
||||
// Small numbers should not wrap, because they have nowhere to go.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, maxEpoch+j, last, VersionWhatever)
|
||||
}
|
||||
|
||||
// Small numbers should not wrap, because they have nowhere to go.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
check(length, maxEpoch+j, last, version)
|
||||
}
|
||||
|
||||
// Large numbers should not wrap either.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, maxEpoch+num, last, version)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
// Large numbers should not wrap either.
|
||||
for j := uint64(0); j < 10; j++ {
|
||||
num := epoch - 1 - j
|
||||
check(length, maxEpoch+num, last, VersionWhatever)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
Context("shortening a packet number for the header", func() {
|
||||
Context("shortening", func() {
|
||||
It("sends out low packet numbers as 2 byte", func() {
|
||||
length := GetPacketNumberLengthForHeader(4, 2, version)
|
||||
length := GetPacketNumberLengthForHeader(4, 2, VersionWhatever)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out high packet numbers as 2 byte, if all ACKs are received", func() {
|
||||
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, version)
|
||||
length := GetPacketNumberLengthForHeader(0xdeadbeef, 0xdeadbeef-1, VersionWhatever)
|
||||
Expect(length).To(Equal(PacketNumberLen2))
|
||||
})
|
||||
|
||||
It("sends out higher packet numbers as 4 bytes, if a lot of ACKs are missing", func() {
|
||||
length := GetPacketNumberLengthForHeader(40000, 2, version)
|
||||
length := GetPacketNumberLengthForHeader(40000, 2, VersionWhatever)
|
||||
Expect(length).To(Equal(PacketNumberLen4))
|
||||
})
|
||||
})
|
||||
|
@ -170,10 +162,10 @@ var _ = Describe("packet number calculation", func() {
|
|||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
@ -182,28 +174,28 @@ var _ = Describe("packet number calculation", func() {
|
|||
for i := uint64(1); i < 10000; i++ {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i / 2)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
epochMask := getEpoch(length, version) - 1
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
||||
epochMask := getEpoch(length, VersionWhatever) - 1
|
||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
||||
It("also works for larger packet numbers", func() {
|
||||
var increment uint64
|
||||
for i := uint64(1); i < getEpoch(PacketNumberLen4, version); i += increment {
|
||||
for i := uint64(1); i < getEpoch(PacketNumberLen4, VersionWhatever); i += increment {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(1)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
epochMask := getEpoch(length, version) - 1
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
||||
epochMask := getEpoch(length, VersionWhatever) - 1
|
||||
wirePacketNumber := uint64(packetNumber) & epochMask
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
|
||||
increment = getEpoch(length, version) / 8
|
||||
increment = getEpoch(length, VersionWhatever) / 8
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -211,10 +203,10 @@ var _ = Describe("packet number calculation", func() {
|
|||
for i := (uint64(1) << 48); i < ((uint64(1) << 63) - 1); i += (uint64(1) << 48) {
|
||||
packetNumber := PacketNumber(i)
|
||||
leastUnacked := PacketNumber(i - 1000)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, version)
|
||||
length := GetPacketNumberLengthForHeader(packetNumber, leastUnacked, VersionWhatever)
|
||||
wirePacketNumber := (uint64(packetNumber) << (64 - length*8)) >> (64 - length*8)
|
||||
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), version)
|
||||
inferedPacketNumber := InferPacketNumber(length, leastUnacked, PacketNumber(wirePacketNumber), VersionWhatever)
|
||||
Expect(inferedPacketNumber).To(Equal(packetNumber))
|
||||
}
|
||||
})
|
||||
|
|
|
@ -23,7 +23,7 @@ const (
|
|||
PacketNumberLen6 PacketNumberLen = 6
|
||||
)
|
||||
|
||||
// The PacketType is the Long Header Type (only used for the IETF draft header format)
|
||||
// The PacketType is the Long Header Type
|
||||
type PacketType uint8
|
||||
|
||||
const (
|
||||
|
@ -71,10 +71,7 @@ const MaxReceivePacketSize ByteCount = 1452
|
|||
// Used in QUIC for congestion window computations in bytes.
|
||||
const DefaultTCPMSS ByteCount = 1460
|
||||
|
||||
// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC)
|
||||
const MinClientHelloSize = 1024
|
||||
|
||||
// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have.
|
||||
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
||||
const MinInitialPacketSize = 1200
|
||||
|
||||
// MaxClientHellos is the maximum number of times we'll send a client hello
|
||||
|
@ -83,8 +80,5 @@ const MinInitialPacketSize = 1200
|
|||
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
|
||||
const MaxClientHellos = 3
|
||||
|
||||
// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
|
||||
const ConnectionIDLenGQUIC = 8
|
||||
|
||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||
const MinConnectionIDLenInitial = 8
|
||||
|
|
|
@ -24,10 +24,6 @@ const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS
|
|||
// session queues for later until it sends a public reset.
|
||||
const MaxUndecryptablePackets = 10
|
||||
|
||||
// PublicResetTimeout is the time to wait before sending a Public Reset when receiving too many undecryptable packets during the handshake
|
||||
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
|
||||
const PublicResetTimeout = 500 * time.Millisecond
|
||||
|
||||
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
|
||||
// This is the value that Google servers are using
|
||||
const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB
|
||||
|
@ -65,12 +61,6 @@ const DefaultMaxIncomingStreams = 100
|
|||
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
|
||||
const DefaultMaxIncomingUniStreams = 100
|
||||
|
||||
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
|
||||
const MaxStreamsMultiplier = 1.1
|
||||
|
||||
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
|
||||
const MaxStreamsMinimumIncrement = 10
|
||||
|
||||
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
|
||||
const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets
|
||||
|
||||
|
@ -103,20 +93,10 @@ const MaxNonRetransmittableAcks = 19
|
|||
// prevents DoS attacks against the streamFrameSorter
|
||||
const MaxStreamFrameSorterGaps = 1000
|
||||
|
||||
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
|
||||
// Value taken from Chrome.
|
||||
const CryptoMaxParams = 128
|
||||
|
||||
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
|
||||
const CryptoParameterMaxLength = 4000
|
||||
|
||||
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
|
||||
// This limits the size of the ClientHello and Certificates that can be received.
|
||||
const MaxCryptoStreamOffset = 16 * (1 << 10)
|
||||
|
||||
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
|
||||
const EphermalKeyLifetime = time.Minute
|
||||
|
||||
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
|
||||
const MinRemoteIdleTimeout = 5 * time.Second
|
||||
|
||||
|
@ -130,16 +110,13 @@ const DefaultHandshakeTimeout = 10 * time.Second
|
|||
// after this time all information about the old connection will be deleted
|
||||
const ClosedSessionDeleteTimeout = time.Minute
|
||||
|
||||
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
|
||||
const NumCachedCertificates = 128
|
||||
|
||||
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
|
||||
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
|
||||
// 1. it reduces the framing overhead
|
||||
// 2. it reduces the head-of-line blocking, when a packet is lost
|
||||
const MinStreamFrameSize ByteCount = 128
|
||||
|
||||
// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write
|
||||
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
||||
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
||||
// The MaxAckFrameSize should be large enough to encode many ACK range,
|
||||
// but must ensure that a maximum size ACK frame fits into one packet.
|
||||
|
|
|
@ -5,7 +5,6 @@ type StreamID uint64
|
|||
|
||||
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open,
|
||||
// when it is allowed to open numStreams bidirectional streams.
|
||||
// It is only valid for IETF QUIC.
|
||||
func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
|
||||
if numStreams == 0 {
|
||||
return 0
|
||||
|
@ -21,7 +20,6 @@ func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
|
|||
|
||||
// MaxUniStreamID is the highest stream ID that the peer is allowed to open,
|
||||
// when it is allowed to open numStreams unidirectional streams.
|
||||
// It is only valid for IETF QUIC.
|
||||
func MaxUniStreamID(numStreams int, pers Perspective) StreamID {
|
||||
if numStreams == 0 {
|
||||
return 0
|
||||
|
|
|
@ -18,32 +18,20 @@ const (
|
|||
|
||||
// The version numbers, making grepping easier
|
||||
const (
|
||||
Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9
|
||||
Version43 VersionNumber = gquicVersion0 + 4*0x100 + 0x3
|
||||
Version44 VersionNumber = gquicVersion0 + 4*0x100 + 0x4
|
||||
VersionTLS VersionNumber = 101
|
||||
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
|
||||
VersionWhatever VersionNumber = 1 // for when the version doesn't matter
|
||||
VersionUnknown VersionNumber = math.MaxUint32
|
||||
)
|
||||
|
||||
// SupportedVersions lists the versions that the server supports
|
||||
// must be in sorted descending order
|
||||
var SupportedVersions = []VersionNumber{
|
||||
Version44,
|
||||
Version43,
|
||||
Version39,
|
||||
}
|
||||
var SupportedVersions = []VersionNumber{VersionTLS}
|
||||
|
||||
// IsValidVersion says if the version is known to quic-go
|
||||
func IsValidVersion(v VersionNumber) bool {
|
||||
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
|
||||
}
|
||||
|
||||
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
|
||||
func (vn VersionNumber) UsesTLS() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
func (vn VersionNumber) String() string {
|
||||
switch vn {
|
||||
case VersionWhatever:
|
||||
|
@ -62,56 +50,9 @@ func (vn VersionNumber) String() string {
|
|||
|
||||
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
|
||||
func (vn VersionNumber) ToAltSvc() string {
|
||||
if vn.isGQUIC() {
|
||||
return fmt.Sprintf("%d", vn.toGQUICVersion())
|
||||
}
|
||||
return fmt.Sprintf("%d", vn)
|
||||
}
|
||||
|
||||
// IsCryptoStream says if a stream is the gQUIC crypto stream.
|
||||
// It never returns true for IETF QUIC.
|
||||
func (vn VersionNumber) IsCryptoStream(id StreamID) bool {
|
||||
return vn.isGQUIC() && id == 1
|
||||
}
|
||||
|
||||
// UsesIETFFrameFormat tells if this version uses the IETF frame format
|
||||
func (vn VersionNumber) UsesIETFFrameFormat() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// UsesIETFHeaderFormat tells if this version uses the IETF header format
|
||||
func (vn VersionNumber) UsesIETFHeaderFormat() bool {
|
||||
return !vn.isGQUIC() || vn >= Version44
|
||||
}
|
||||
|
||||
// UsesLengthInHeader tells if this version uses the Length field in the IETF header
|
||||
func (vn VersionNumber) UsesLengthInHeader() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// UsesTokenInHeader tells if this version uses the Token field in the IETF header
|
||||
func (vn VersionNumber) UsesTokenInHeader() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames
|
||||
func (vn VersionNumber) UsesStopWaitingFrames() bool {
|
||||
return vn.isGQUIC() && vn <= Version43
|
||||
}
|
||||
|
||||
// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers
|
||||
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
|
||||
return !vn.isGQUIC()
|
||||
}
|
||||
|
||||
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
|
||||
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
|
||||
if !vn.isGQUIC() {
|
||||
return true
|
||||
}
|
||||
return id != 1 && id != 3
|
||||
}
|
||||
|
||||
func (vn VersionNumber) isGQUIC() bool {
|
||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||
}
|
||||
|
|
|
@ -10,39 +10,18 @@ var _ = Describe("Version", func() {
|
|||
return v&0x0f0f0f0f == 0x0a0a0a0a
|
||||
}
|
||||
|
||||
// version numbers taken from the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
|
||||
It("has the right gQUIC version number", func() {
|
||||
Expect(Version39).To(BeEquivalentTo(0x51303339))
|
||||
Expect(Version43).To(BeEquivalentTo(0x51303433))
|
||||
Expect(Version44).To(BeEquivalentTo(0x51303434))
|
||||
})
|
||||
|
||||
It("says if a version is valid", func() {
|
||||
Expect(IsValidVersion(Version39)).To(BeTrue())
|
||||
Expect(IsValidVersion(Version43)).To(BeTrue())
|
||||
Expect(IsValidVersion(Version44)).To(BeTrue())
|
||||
Expect(IsValidVersion(VersionTLS)).To(BeTrue())
|
||||
Expect(IsValidVersion(VersionWhatever)).To(BeFalse())
|
||||
Expect(IsValidVersion(VersionUnknown)).To(BeFalse())
|
||||
Expect(IsValidVersion(1234)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("says if a version supports TLS", func() {
|
||||
Expect(Version39.UsesTLS()).To(BeFalse())
|
||||
Expect(Version43.UsesTLS()).To(BeFalse())
|
||||
Expect(Version44.UsesTLS()).To(BeFalse())
|
||||
Expect(VersionTLS.UsesTLS()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("versions don't have reserved version numbers", func() {
|
||||
Expect(isReservedVersion(Version39)).To(BeFalse())
|
||||
Expect(isReservedVersion(Version43)).To(BeFalse())
|
||||
Expect(isReservedVersion(Version44)).To(BeFalse())
|
||||
Expect(isReservedVersion(VersionTLS)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("has the right string representation", func() {
|
||||
Expect(Version39.String()).To(Equal("gQUIC 39"))
|
||||
Expect(VersionTLS.String()).To(ContainSubstring("TLS"))
|
||||
Expect(VersionWhatever.String()).To(Equal("whatever"))
|
||||
Expect(VersionUnknown.String()).To(Equal("unknown"))
|
||||
|
@ -55,88 +34,7 @@ var _ = Describe("Version", func() {
|
|||
})
|
||||
|
||||
It("has the right representation for the H2 Alt-Svc tag", func() {
|
||||
Expect(Version39.ToAltSvc()).To(Equal("39"))
|
||||
Expect(Version43.ToAltSvc()).To(Equal("43"))
|
||||
Expect(Version44.ToAltSvc()).To(Equal("44"))
|
||||
Expect(VersionTLS.ToAltSvc()).To(Equal("101"))
|
||||
// check with unsupported version numbers from the wiki
|
||||
Expect(VersionNumber(0x51303133).ToAltSvc()).To(Equal("13"))
|
||||
Expect(VersionNumber(0x51303235).ToAltSvc()).To(Equal("25"))
|
||||
Expect(VersionNumber(0x51303438).ToAltSvc()).To(Equal("48"))
|
||||
})
|
||||
|
||||
It("says if a stream is the crypto stream, for gQUIC", func() {
|
||||
for _, v := range []VersionNumber{Version39, Version43, Version44} {
|
||||
version := v
|
||||
Expect(version.IsCryptoStream(1)).To(BeTrue())
|
||||
Expect(version.IsCryptoStream(2)).To(BeFalse())
|
||||
Expect(version.IsCryptoStream(3)).To(BeFalse())
|
||||
Expect(version.IsCryptoStream(4)).To(BeFalse())
|
||||
Expect(version.IsCryptoStream(5)).To(BeFalse())
|
||||
}
|
||||
})
|
||||
|
||||
It("says if a stream is the crypto stream, for TLS", func() {
|
||||
// all streams contribute to connection-level flow control
|
||||
for id := StreamID(0); id < 10; id++ {
|
||||
Expect(VersionTLS.IsCryptoStream(id)).To(BeFalse())
|
||||
}
|
||||
})
|
||||
|
||||
It("tells if a version uses the IETF frame types", func() {
|
||||
Expect(Version39.UsesIETFFrameFormat()).To(BeFalse())
|
||||
Expect(Version43.UsesIETFFrameFormat()).To(BeFalse())
|
||||
Expect(Version44.UsesIETFFrameFormat()).To(BeFalse())
|
||||
Expect(VersionTLS.UsesIETFFrameFormat()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tells if a version uses the IETF header format", func() {
|
||||
Expect(Version39.UsesIETFHeaderFormat()).To(BeFalse())
|
||||
Expect(Version43.UsesIETFHeaderFormat()).To(BeFalse())
|
||||
Expect(Version44.UsesIETFHeaderFormat()).To(BeTrue())
|
||||
Expect(VersionTLS.UsesIETFHeaderFormat()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tells if a version uses varint packet numbers", func() {
|
||||
Expect(Version39.UsesVarintPacketNumbers()).To(BeFalse())
|
||||
Expect(Version43.UsesVarintPacketNumbers()).To(BeFalse())
|
||||
Expect(Version44.UsesVarintPacketNumbers()).To(BeFalse())
|
||||
Expect(VersionTLS.UsesVarintPacketNumbers()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tells if a version uses the Length field in the IETF header", func() {
|
||||
Expect(Version44.UsesLengthInHeader()).To(BeFalse())
|
||||
Expect(VersionTLS.UsesLengthInHeader()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tells if a version uses the Token field in the IETF header", func() {
|
||||
Expect(Version44.UsesTokenInHeader()).To(BeFalse())
|
||||
Expect(VersionTLS.UsesTokenInHeader()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("tells if a version uses STOP_WAITING frames", func() {
|
||||
Expect(Version39.UsesStopWaitingFrames()).To(BeTrue())
|
||||
Expect(Version43.UsesStopWaitingFrames()).To(BeTrue())
|
||||
Expect(Version44.UsesStopWaitingFrames()).To(BeFalse())
|
||||
Expect(VersionTLS.UsesStopWaitingFrames()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("says if a stream contributes to connection-level flowcontrol, for gQUIC", func() {
|
||||
for _, v := range []VersionNumber{Version39, Version43, Version44} {
|
||||
version := v
|
||||
Expect(version.StreamContributesToConnectionFlowControl(1)).To(BeFalse())
|
||||
Expect(version.StreamContributesToConnectionFlowControl(2)).To(BeTrue())
|
||||
Expect(version.StreamContributesToConnectionFlowControl(3)).To(BeFalse())
|
||||
Expect(version.StreamContributesToConnectionFlowControl(4)).To(BeTrue())
|
||||
Expect(version.StreamContributesToConnectionFlowControl(5)).To(BeTrue())
|
||||
}
|
||||
})
|
||||
|
||||
It("says if a stream contributes to connection-level flowcontrol, for TLS", func() {
|
||||
// all streams contribute to connection-level flow control
|
||||
for id := StreamID(0); id < 10; id++ {
|
||||
Expect(VersionTLS.StreamContributesToConnectionFlowControl(id)).To(BeTrue())
|
||||
}
|
||||
})
|
||||
|
||||
It("recognizes supported versions", func() {
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
package utils
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// An AtomicBool is an atomic bool
|
||||
type AtomicBool struct {
|
||||
v int32
|
||||
}
|
||||
|
||||
// Set sets the value
|
||||
func (a *AtomicBool) Set(value bool) {
|
||||
var n int32
|
||||
if value {
|
||||
n = 1
|
||||
}
|
||||
atomic.StoreInt32(&a.v, n)
|
||||
}
|
||||
|
||||
// Get gets the value
|
||||
func (a *AtomicBool) Get() bool {
|
||||
return atomic.LoadInt32(&a.v) != 0
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Atomic Bool", func() {
|
||||
var a *AtomicBool
|
||||
|
||||
BeforeEach(func() {
|
||||
a = &AtomicBool{}
|
||||
})
|
||||
|
||||
It("has the right default value", func() {
|
||||
Expect(a.Get()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("sets the value to true", func() {
|
||||
a.Set(true)
|
||||
Expect(a.Get()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("sets the value to false", func() {
|
||||
a.Set(true)
|
||||
a.Set(false)
|
||||
Expect(a.Get()).To(BeFalse())
|
||||
})
|
||||
})
|
|
@ -75,27 +75,6 @@ var _ = Describe("Big Endian encoding / decoding", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("WriteUint24", func() {
|
||||
It("outputs 3 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint24(b, uint32(1))
|
||||
Expect(b.Len()).To(Equal(3))
|
||||
})
|
||||
|
||||
It("outputs a big endian", func() {
|
||||
num := uint32(0x010203)
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint24(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x01, 0x02, 0x03}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 24 bits", func() {
|
||||
num := uint32(0x01020304)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { BigEndian.WriteUint24(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint32", func() {
|
||||
It("outputs 4 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
|
@ -111,69 +90,6 @@ var _ = Describe("Big Endian encoding / decoding", func() {
|
|||
})
|
||||
})
|
||||
|
||||
Context("WriteUint40", func() {
|
||||
It("outputs 5 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint40(b, uint64(1))
|
||||
Expect(b.Len()).To(Equal(5))
|
||||
})
|
||||
|
||||
It("outputs a big endian", func() {
|
||||
num := uint64(0xDECAFBAD42)
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint40(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0xDE, 0xCA, 0xFB, 0xAD, 0x42}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 40 bits", func() {
|
||||
num := uint64(0x010203040506)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { BigEndian.WriteUint40(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint48", func() {
|
||||
It("outputs 6 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint48(b, uint64(1))
|
||||
Expect(b.Len()).To(Equal(6))
|
||||
})
|
||||
|
||||
It("outputs a big endian", func() {
|
||||
num := uint64(0xDEADBEEFCAFE)
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint48(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 48 bits", func() {
|
||||
num := uint64(0xDEADBEEFCAFE01)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { BigEndian.WriteUint48(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint56", func() {
|
||||
It("outputs 7 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint56(b, uint64(1))
|
||||
Expect(b.Len()).To(Equal(7))
|
||||
})
|
||||
|
||||
It("outputs a big endian", func() {
|
||||
num := uint64(0xEEDDCCBBAA9988)
|
||||
b := &bytes.Buffer{}
|
||||
BigEndian.WriteUint56(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0xEE, 0xDD, 0xCC, 0xBB, 0xAA, 0x99, 0x88}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 56 bits", func() {
|
||||
num := uint64(0xEEDDCCBBAA998801)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { BigEndian.WriteUint56(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint64", func() {
|
||||
It("outputs 8 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
|
|
|
@ -13,13 +13,6 @@ type ByteOrder interface {
|
|||
ReadUint16(io.ByteReader) (uint16, error)
|
||||
|
||||
WriteUint64(*bytes.Buffer, uint64)
|
||||
WriteUint56(*bytes.Buffer, uint64)
|
||||
WriteUint48(*bytes.Buffer, uint64)
|
||||
WriteUint40(*bytes.Buffer, uint64)
|
||||
WriteUint32(*bytes.Buffer, uint32)
|
||||
WriteUint24(*bytes.Buffer, uint32)
|
||||
WriteUint16(*bytes.Buffer, uint16)
|
||||
|
||||
ReadUfloat16(io.ByteReader) (uint64, error)
|
||||
WriteUfloat16(*bytes.Buffer, uint64)
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package utils
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
@ -97,61 +96,12 @@ func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
|||
})
|
||||
}
|
||||
|
||||
// WriteUint56 writes 56 bit of a uint64
|
||||
func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 56) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint48 writes 48 bit of a uint64
|
||||
func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 48) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i >> 40), uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint40 writes 40 bit of a uint64
|
||||
func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 40) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i >> 32),
|
||||
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32
|
||||
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint24 writes 24 bit of a uint32
|
||||
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
||||
if i >= (1 << 24) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
|
||||
}
|
||||
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint16 writes a uint16
|
||||
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
||||
b.Write([]byte{uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
|
||||
return readUfloat16(b, l)
|
||||
}
|
||||
|
||||
func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
|
||||
writeUfloat16(b, l, val)
|
||||
}
|
||||
|
|
|
@ -1,157 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// LittleEndian is the little-endian implementation of ByteOrder.
|
||||
var LittleEndian ByteOrder = littleEndian{}
|
||||
|
||||
type littleEndian struct{}
|
||||
|
||||
var _ ByteOrder = &littleEndian{}
|
||||
|
||||
// ReadUintN reads N bytes
|
||||
func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
|
||||
var res uint64
|
||||
for i := uint8(0); i < length; i++ {
|
||||
bt, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res ^= uint64(bt) << (i * 8)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ReadUint64 reads a uint64
|
||||
func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) {
|
||||
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
|
||||
var err error
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b4, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b5, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b6, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b7, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b8, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
|
||||
}
|
||||
|
||||
// ReadUint32 reads a uint32
|
||||
func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
||||
var b1, b2, b3, b4 uint8
|
||||
var err error
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b4, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
|
||||
}
|
||||
|
||||
// ReadUint16 reads a uint16
|
||||
func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) {
|
||||
var b1, b2 uint8
|
||||
var err error
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(b1) + uint16(b2)<<8, nil
|
||||
}
|
||||
|
||||
// WriteUint64 writes a uint64
|
||||
func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) {
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
||||
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint56 writes 56 bit of a uint64
|
||||
func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 56) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
||||
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint48 writes 48 bit of a uint64
|
||||
func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 48) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
|
||||
uint8(i >> 32), uint8(i >> 40),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint40 writes 40 bit of a uint64
|
||||
func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) {
|
||||
if i >= (1 << 40) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
|
||||
}
|
||||
b.Write([]byte{
|
||||
uint8(i), uint8(i >> 8), uint8(i >> 16),
|
||||
uint8(i >> 24), uint8(i >> 32),
|
||||
})
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32
|
||||
func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)})
|
||||
}
|
||||
|
||||
// WriteUint24 writes 24 bit of a uint32
|
||||
func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
||||
if i >= (1 << 24) {
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
|
||||
}
|
||||
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)})
|
||||
}
|
||||
|
||||
// WriteUint16 writes a uint16
|
||||
func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
||||
b.Write([]byte{uint8(i), uint8(i >> 8)})
|
||||
}
|
||||
|
||||
func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
|
||||
return readUfloat16(b, l)
|
||||
}
|
||||
|
||||
func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
|
||||
writeUfloat16(b, l, val)
|
||||
}
|
|
@ -1,212 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Little Endian encoding / decoding", func() {
|
||||
Context("ReadUint16", func() {
|
||||
It("reads a little endian", func() {
|
||||
b := []byte{0x13, 0xEF}
|
||||
val, err := LittleEndian.ReadUint16(bytes.NewReader(b))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(val).To(Equal(uint16(0xEF13)))
|
||||
})
|
||||
|
||||
It("throws an error if less than 2 bytes are passed", func() {
|
||||
b := []byte{0x13, 0xEF}
|
||||
for i := 0; i < len(b); i++ {
|
||||
_, err := LittleEndian.ReadUint16(bytes.NewReader(b[:i]))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("ReadUint32", func() {
|
||||
It("reads a little endian", func() {
|
||||
b := []byte{0x12, 0x35, 0xAB, 0xFF}
|
||||
val, err := LittleEndian.ReadUint32(bytes.NewReader(b))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(val).To(Equal(uint32(0xFFAB3512)))
|
||||
})
|
||||
|
||||
It("throws an error if less than 4 bytes are passed", func() {
|
||||
b := []byte{0x12, 0x35, 0xAB, 0xFF}
|
||||
for i := 0; i < len(b); i++ {
|
||||
_, err := LittleEndian.ReadUint32(bytes.NewReader(b[:i]))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("ReadUint64", func() {
|
||||
It("reads a little endian", func() {
|
||||
b := []byte{0x12, 0x35, 0xAB, 0xFF, 0xEF, 0xBE, 0xAD, 0xDE}
|
||||
val, err := LittleEndian.ReadUint64(bytes.NewReader(b))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(val).To(Equal(uint64(0xDEADBEEFFFAB3512)))
|
||||
})
|
||||
|
||||
It("throws an error if less than 8 bytes are passed", func() {
|
||||
b := []byte{0x12, 0x35, 0xAB, 0xFF, 0xEF, 0xBE, 0xAD, 0xDE}
|
||||
for i := 0; i < len(b); i++ {
|
||||
_, err := LittleEndian.ReadUint64(bytes.NewReader(b[:i]))
|
||||
Expect(err).To(MatchError(io.EOF))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint16", func() {
|
||||
It("outputs 2 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint16(b, uint16(1))
|
||||
Expect(b.Len()).To(Equal(2))
|
||||
})
|
||||
|
||||
It("outputs a little endian", func() {
|
||||
num := uint16(0xFF11)
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint16(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x11, 0xFF}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint24", func() {
|
||||
It("outputs 3 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint24(b, uint32(1))
|
||||
Expect(b.Len()).To(Equal(3))
|
||||
})
|
||||
|
||||
It("outputs a little endian", func() {
|
||||
num := uint32(0x010203)
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint24(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x03, 0x02, 0x01}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 24 bits", func() {
|
||||
num := uint32(0x01020304)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { LittleEndian.WriteUint24(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint32", func() {
|
||||
It("outputs 4 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint32(b, uint32(1))
|
||||
Expect(b.Len()).To(Equal(4))
|
||||
})
|
||||
|
||||
It("outputs a little endian", func() {
|
||||
num := uint32(0xEFAC3512)
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint32(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x12, 0x35, 0xAC, 0xEF}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint40", func() {
|
||||
It("outputs 5 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint40(b, uint64(1))
|
||||
Expect(b.Len()).To(Equal(5))
|
||||
})
|
||||
|
||||
It("outputs a little endian", func() {
|
||||
num := uint64(0x0102030405)
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint40(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x05, 0x04, 0x03, 0x02, 0x01}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 40 bits", func() {
|
||||
num := uint64(0x010203040506)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { LittleEndian.WriteUint40(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint48", func() {
|
||||
It("outputs 6 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint48(b, uint64(1))
|
||||
Expect(b.Len()).To(Equal(6))
|
||||
})
|
||||
|
||||
It("outputs a little endian", func() {
|
||||
num := uint64(0xDEADBEEFCAFE)
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint48(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0xFE, 0xCA, 0xEF, 0xBE, 0xAD, 0xDE}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 48 bits", func() {
|
||||
num := uint64(0xDEADBEEFCAFE01)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { LittleEndian.WriteUint48(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint56", func() {
|
||||
It("outputs 7 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint56(b, uint64(1))
|
||||
Expect(b.Len()).To(Equal(7))
|
||||
})
|
||||
|
||||
It("outputs a little endian", func() {
|
||||
num := uint64(0xEEDDCCBBAA9988)
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint56(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE}))
|
||||
})
|
||||
|
||||
It("panics if the value doesn't fit into 56 bits", func() {
|
||||
num := uint64(0xEEDDCCBBAA998801)
|
||||
b := &bytes.Buffer{}
|
||||
Expect(func() { LittleEndian.WriteUint56(b, num) }).Should(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
Context("WriteUint64", func() {
|
||||
It("outputs 8 bytes", func() {
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint64(b, uint64(1))
|
||||
Expect(b.Len()).To(Equal(8))
|
||||
})
|
||||
|
||||
It("outputs a little endian", func() {
|
||||
num := uint64(0xFFEEDDCCBBAA9988)
|
||||
b := &bytes.Buffer{}
|
||||
LittleEndian.WriteUint64(b, num)
|
||||
Expect(b.Bytes()).To(Equal([]byte{0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("ReadUintN", func() {
|
||||
It("reads n bytes", func() {
|
||||
m := map[uint8]uint64{
|
||||
0: 0x0, 1: 0x01, 2: 0x0201, 3: 0x030201, 4: 0x04030201, 5: 0x0504030201,
|
||||
6: 0x060504030201, 7: 0x07060504030201, 8: 0x0807060504030201,
|
||||
}
|
||||
for n, expected := range m {
|
||||
b := bytes.NewReader([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})
|
||||
i, err := LittleEndian.ReadUintN(b, n)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(i).To(Equal(expected))
|
||||
}
|
||||
})
|
||||
|
||||
It("errors", func() {
|
||||
b := bytes.NewReader([]byte{0x1, 0x2})
|
||||
_, err := LittleEndian.ReadUintN(b, 3)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue