From 82508f156275481d710bcc6cf33ddf1e21b850e6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 30 Sep 2018 23:22:04 -0700 Subject: [PATCH] use tls-tris instead of mint --- client.go | 59 +- client_test.go | 35 +- crypto_stream.go | 69 +- crypto_stream_manager.go | 55 + crypto_stream_manager_test.go | 65 + crypto_stream_test.go | 89 +- framer.go | 44 +- framer_test.go | 13 +- h2quic/client.go | 2 +- internal/ackhandler/sent_packet_handler.go | 20 +- .../ackhandler/sent_packet_handler_test.go | 3 +- internal/ackhandler/sent_packet_history.go | 9 +- internal/crypto/hkdf.go | 2 +- internal/crypto/key_derivation.go | 51 - internal/crypto/key_derivation_test.go | 56 - internal/crypto/null_aead_aesgcm.go | 8 +- internal/handshake/aead.go | 58 + internal/handshake/aead_test.go | 51 + internal/handshake/crypto_setup_client.go | 2 +- .../handshake/crypto_setup_client_test.go | 20 +- internal/handshake/crypto_setup_server.go | 2 +- .../handshake/crypto_setup_server_test.go | 36 +- internal/handshake/crypto_setup_tls.go | 530 +++- internal/handshake/crypto_setup_tls_test.go | 398 ++- internal/handshake/crypto_stream_conn.go | 69 - internal/handshake/crypto_stream_conn_test.go | 41 - internal/handshake/interface.go | 36 +- internal/handshake/mock_mint_tls_test.go | 72 - internal/handshake/mockgen.go | 3 - internal/handshake/qtls.go | 48 + internal/handshake/tls_extension.go | 20 - .../handshake/tls_extension_handler_client.go | 68 +- .../tls_extension_handler_client_test.go | 188 +- .../handshake/tls_extension_handler_server.go | 72 +- .../tls_extension_handler_server_test.go | 154 +- internal/handshake/tls_extension_test.go | 28 +- internal/mocks/mockgen.go | 1 - internal/mocks/tls_extension_handler.go | 72 - internal/protocol/server_parameters.go | 4 + internal/protocol/stream_id.go | 2 +- internal/protocol/stream_id_test.go | 4 +- internal/protocol/version.go | 19 +- internal/protocol/version_test.go | 29 +- mint_utils.go | 52 - mint_utils_test.go | 65 - mock_crypto_data_handler.go | 47 + mock_crypto_stream_test.go | 142 +- mock_quic_aead_test.go | 13 + mock_sealing_manager_test.go | 42 +- mockgen.go | 5 +- packet_packer.go | 76 +- packet_packer_legacy.go | 14 +- packet_packer_legacy_test.go | 12 +- packet_packer_test.go | 1257 ++++---- packet_unpacker.go | 21 +- packet_unpacker_test.go | 14 +- quic_suite_test.go | 2 - receive_stream.go | 2 +- send_stream.go | 2 +- server_session.go | 2 +- server_tls.go | 17 +- server_tls_test.go | 16 +- session.go | 164 +- session_test.go | 35 +- streams_map.go | 4 +- streams_map_legacy.go | 7 +- streams_map_legacy_test.go | 7 + streams_map_test.go | 6 +- vendor/github.com/bifurcation/mint/LICENSE.md | 21 - vendor/github.com/bifurcation/mint/README.md | 94 - vendor/github.com/bifurcation/mint/alert.go | 101 - .../bifurcation/mint/client-state-machine.go | 1083 ------- vendor/github.com/bifurcation/mint/common.go | 266 -- vendor/github.com/bifurcation/mint/conn.go | 921 ------ .../bifurcation/mint/cookie-protector.go | 86 - vendor/github.com/bifurcation/mint/crypto.go | 667 ---- vendor/github.com/bifurcation/mint/dtls.go | 222 -- .../github.com/bifurcation/mint/extensions.go | 626 ---- vendor/github.com/bifurcation/mint/ffdhe.go | 147 - .../bifurcation/mint/frame-reader.go | 98 - .../bifurcation/mint/handshake-layer.go | 551 ---- .../bifurcation/mint/handshake-messages.go | 481 --- vendor/github.com/bifurcation/mint/log.go | 55 - vendor/github.com/bifurcation/mint/mint.svg | 101 - .../bifurcation/mint/negotiation.go | 218 -- .../bifurcation/mint/record-layer.go | 458 --- .../bifurcation/mint/server-state-machine.go | 1177 ------- .../bifurcation/mint/state-machine.go | 247 -- vendor/github.com/bifurcation/mint/timer.go | 122 - vendor/github.com/bifurcation/mint/tls.go | 179 -- vendor/github.com/marten-seemann/qtls/13.go | 1162 +++++++ .../github.com/marten-seemann/qtls/README.md | 107 + .../github.com/marten-seemann/qtls/alert.go | 84 + vendor/github.com/marten-seemann/qtls/auth.go | 107 + .../marten-seemann/qtls/cipher_suites.go | 437 +++ .../github.com/marten-seemann/qtls/common.go | 1215 +++++++ vendor/github.com/marten-seemann/qtls/conn.go | 1766 +++++++++++ .../marten-seemann/qtls/handshake_client.go | 1006 ++++++ .../marten-seemann/qtls/handshake_messages.go | 2781 +++++++++++++++++ .../marten-seemann/qtls/handshake_server.go | 943 ++++++ vendor/github.com/marten-seemann/qtls/hkdf.go | 58 + .../marten-seemann/qtls/key_agreement.go | 402 +++ vendor/github.com/marten-seemann/qtls/prf.go | 355 +++ .../marten-seemann/qtls/subcerts.go | 392 +++ .../github.com/marten-seemann/qtls/ticket.go | 326 ++ vendor/github.com/marten-seemann/qtls/tls.go | 297 ++ .../chacha20poly1305/chacha20poly1305.go | 101 + .../chacha20poly1305_amd64.go | 86 + .../chacha20poly1305/chacha20poly1305_amd64.s | 2695 ++++++++++++++++ .../chacha20poly1305_generic.go | 81 + .../chacha20poly1305_noasm.go | 15 + .../chacha20poly1305/xchacha20poly1305.go | 104 + .../internal/chacha20/chacha_generic.go | 264 ++ .../crypto/internal/chacha20/chacha_noasm.go | 16 + .../crypto/internal/chacha20/chacha_s390x.go | 30 + .../x/crypto/internal/chacha20/chacha_s390x.s | 283 ++ .../x/crypto/internal/chacha20/xor.go | 43 + .../x/crypto/internal/subtle/aliasing.go | 32 + .../internal/subtle/aliasing_appengine.go | 35 + .../golang.org/x/crypto/poly1305/poly1305.go | 33 + .../golang.org/x/crypto/poly1305/sum_amd64.go | 22 + .../golang.org/x/crypto/poly1305/sum_amd64.s | 125 + .../golang.org/x/crypto/poly1305/sum_arm.go | 22 + vendor/golang.org/x/crypto/poly1305/sum_arm.s | 427 +++ .../golang.org/x/crypto/poly1305/sum_noasm.go | 14 + .../golang.org/x/crypto/poly1305/sum_ref.go | 139 + .../golang.org/x/crypto/poly1305/sum_s390x.go | 49 + .../golang.org/x/crypto/poly1305/sum_s390x.s | 400 +++ .../x/crypto/poly1305/sum_vmsl_s390x.s | 931 ++++++ vendor/golang.org/x/sys/LICENSE | 27 + vendor/golang.org/x/sys/PATENTS | 22 + vendor/golang.org/x/sys/cpu/cpu.go | 38 + vendor/golang.org/x/sys/cpu/cpu_arm.go | 7 + vendor/golang.org/x/sys/cpu/cpu_arm64.go | 7 + vendor/golang.org/x/sys/cpu/cpu_gc_x86.go | 16 + vendor/golang.org/x/sys/cpu/cpu_gccgo.c | 43 + vendor/golang.org/x/sys/cpu/cpu_gccgo.go | 26 + vendor/golang.org/x/sys/cpu/cpu_mips64x.go | 9 + vendor/golang.org/x/sys/cpu/cpu_mipsx.go | 9 + vendor/golang.org/x/sys/cpu/cpu_ppc64x.go | 9 + vendor/golang.org/x/sys/cpu/cpu_s390x.go | 7 + vendor/golang.org/x/sys/cpu/cpu_x86.go | 55 + vendor/golang.org/x/sys/cpu/cpu_x86.s | 27 + vendor/vendor.json | 42 +- 144 files changed, 20124 insertions(+), 10157 deletions(-) create mode 100644 crypto_stream_manager.go create mode 100644 crypto_stream_manager_test.go delete mode 100644 internal/crypto/key_derivation.go delete mode 100644 internal/crypto/key_derivation_test.go create mode 100644 internal/handshake/aead.go create mode 100644 internal/handshake/aead_test.go delete mode 100644 internal/handshake/crypto_stream_conn.go delete mode 100644 internal/handshake/crypto_stream_conn_test.go delete mode 100644 internal/handshake/mock_mint_tls_test.go delete mode 100644 internal/handshake/mockgen.go create mode 100644 internal/handshake/qtls.go delete mode 100644 internal/mocks/tls_extension_handler.go delete mode 100644 mint_utils.go delete mode 100644 mint_utils_test.go create mode 100644 mock_crypto_data_handler.go delete mode 100644 vendor/github.com/bifurcation/mint/LICENSE.md delete mode 100644 vendor/github.com/bifurcation/mint/README.md delete mode 100644 vendor/github.com/bifurcation/mint/alert.go delete mode 100644 vendor/github.com/bifurcation/mint/client-state-machine.go delete mode 100644 vendor/github.com/bifurcation/mint/common.go delete mode 100644 vendor/github.com/bifurcation/mint/conn.go delete mode 100644 vendor/github.com/bifurcation/mint/cookie-protector.go delete mode 100644 vendor/github.com/bifurcation/mint/crypto.go delete mode 100644 vendor/github.com/bifurcation/mint/dtls.go delete mode 100644 vendor/github.com/bifurcation/mint/extensions.go delete mode 100644 vendor/github.com/bifurcation/mint/ffdhe.go delete mode 100644 vendor/github.com/bifurcation/mint/frame-reader.go delete mode 100644 vendor/github.com/bifurcation/mint/handshake-layer.go delete mode 100644 vendor/github.com/bifurcation/mint/handshake-messages.go delete mode 100644 vendor/github.com/bifurcation/mint/log.go delete mode 100644 vendor/github.com/bifurcation/mint/mint.svg delete mode 100644 vendor/github.com/bifurcation/mint/negotiation.go delete mode 100644 vendor/github.com/bifurcation/mint/record-layer.go delete mode 100644 vendor/github.com/bifurcation/mint/server-state-machine.go delete mode 100644 vendor/github.com/bifurcation/mint/state-machine.go delete mode 100644 vendor/github.com/bifurcation/mint/timer.go delete mode 100644 vendor/github.com/bifurcation/mint/tls.go create mode 100644 vendor/github.com/marten-seemann/qtls/13.go create mode 100644 vendor/github.com/marten-seemann/qtls/README.md create mode 100644 vendor/github.com/marten-seemann/qtls/alert.go create mode 100644 vendor/github.com/marten-seemann/qtls/auth.go create mode 100644 vendor/github.com/marten-seemann/qtls/cipher_suites.go create mode 100644 vendor/github.com/marten-seemann/qtls/common.go create mode 100644 vendor/github.com/marten-seemann/qtls/conn.go create mode 100644 vendor/github.com/marten-seemann/qtls/handshake_client.go create mode 100644 vendor/github.com/marten-seemann/qtls/handshake_messages.go create mode 100644 vendor/github.com/marten-seemann/qtls/handshake_server.go create mode 100644 vendor/github.com/marten-seemann/qtls/hkdf.go create mode 100644 vendor/github.com/marten-seemann/qtls/key_agreement.go create mode 100644 vendor/github.com/marten-seemann/qtls/prf.go create mode 100644 vendor/github.com/marten-seemann/qtls/subcerts.go create mode 100644 vendor/github.com/marten-seemann/qtls/ticket.go create mode 100644 vendor/github.com/marten-seemann/qtls/tls.go create mode 100644 vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305.go create mode 100644 vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.go create mode 100644 vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.s create mode 100644 vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_generic.go create mode 100644 vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_noasm.go create mode 100644 vendor/golang.org/x/crypto/chacha20poly1305/xchacha20poly1305.go create mode 100644 vendor/golang.org/x/crypto/internal/chacha20/chacha_generic.go create mode 100644 vendor/golang.org/x/crypto/internal/chacha20/chacha_noasm.go create mode 100644 vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.go create mode 100644 vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.s create mode 100644 vendor/golang.org/x/crypto/internal/chacha20/xor.go create mode 100644 vendor/golang.org/x/crypto/internal/subtle/aliasing.go create mode 100644 vendor/golang.org/x/crypto/internal/subtle/aliasing_appengine.go create mode 100644 vendor/golang.org/x/crypto/poly1305/poly1305.go create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_amd64.go create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_amd64.s create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_arm.go create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_arm.s create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_noasm.go create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_ref.go create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_s390x.go create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_s390x.s create mode 100644 vendor/golang.org/x/crypto/poly1305/sum_vmsl_s390x.s create mode 100644 vendor/golang.org/x/sys/LICENSE create mode 100644 vendor/golang.org/x/sys/PATENTS create mode 100644 vendor/golang.org/x/sys/cpu/cpu.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_arm.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_arm64.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_gc_x86.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_gccgo.c create mode 100644 vendor/golang.org/x/sys/cpu/cpu_gccgo.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_mips64x.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_mipsx.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_ppc64x.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_s390x.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_x86.go create mode 100644 vendor/golang.org/x/sys/cpu/cpu_x86.s diff --git a/client.go b/client.go index 280c3fcd..6ec86ef3 100644 --- a/client.go +++ b/client.go @@ -9,7 +9,6 @@ import ( "net" "sync" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -33,9 +32,8 @@ type client struct { receivedVersionNegotiationPacket bool negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet - tlsConf *tls.Config - mintConf *mint.Config - config *Config + tlsConf *tls.Config + config *Config srcConnID protocol.ConnectionID destConnID protocol.ConnectionID @@ -304,27 +302,10 @@ func (c *client) dialGQUIC(ctx context.Context) error { } func (c *client) dialTLS(ctx context.Context) error { - params := &handshake.TransportParameters{ - StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, - ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, - IdleTimeout: c.config.IdleTimeout, - OmitConnectionID: c.config.RequestConnectionIDOmission, - MaxBidiStreams: uint16(c.config.MaxIncomingStreams), - MaxUniStreams: uint16(c.config.MaxIncomingUniStreams), - DisableMigration: true, - } - extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger) - mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient) - if err != nil { + if err := c.createNewTLSSession(c.version); err != nil { return err } - mintConf.ExtensionHandler = extHandler - c.mintConf = mintConf - - if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { - return err - } - err = c.establishSecureConnection(ctx) + err := c.establishSecureConnection(ctx) if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion { return c.dial(ctx) } @@ -401,15 +382,9 @@ func (c *client) handlePacketImpl(p *receivedPacket) error { } } - if p.header.IsLongHeader { - switch p.header.Type { - case protocol.PacketTypeRetry: - c.handleRetryPacket(p.header) - return nil - case protocol.PacketTypeHandshake, protocol.PacketType0RTT: - default: - return fmt.Errorf("Received unsupported packet type: %s", p.header.Type) - } + if p.header.Type == protocol.PacketTypeRetry { + c.handleRetryPacket(p.header) + return nil } // this is the first packet we are receiving @@ -526,10 +501,17 @@ func (c *client) createNewGQUICSession() error { return nil } -func (c *client) createNewTLSSession( - paramsChan <-chan handshake.TransportParameters, - version protocol.VersionNumber, -) error { +func (c *client) createNewTLSSession(version protocol.VersionNumber) error { + params := &handshake.TransportParameters{ + StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, + ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, + IdleTimeout: c.config.IdleTimeout, + OmitConnectionID: c.config.RequestConnectionIDOmission, + MaxBidiStreams: uint16(c.config.MaxIncomingStreams), + MaxUniStreams: uint16(c.config.MaxIncomingUniStreams), + DisableMigration: true, + } + c.mutex.Lock() defer c.mutex.Unlock() runner := &runner{ @@ -543,8 +525,9 @@ func (c *client) createNewTLSSession( c.destConnID, c.srcConnID, c.config, - c.mintConf, - paramsChan, + c.tlsConf, + params, + c.initialVersion, 1, c.logger, c.version, diff --git a/client_test.go b/client_test.go index bf3fff08..ed89d121 100644 --- a/client_test.go +++ b/client_test.go @@ -10,7 +10,6 @@ import ( "os" "time" - "github.com/bifurcation/mint" "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -524,8 +523,9 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, configP *Config, - _ *mint.Config, - paramsChan <-chan handshake.TransportParameters, + _ *tls.Config, + params *handshake.TransportParameters, + _ protocol.VersionNumber, /* initial version */ _ protocol.PacketNumber, _ utils.Logger, versionP protocol.VersionNumber, @@ -585,8 +585,9 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, - _ *mint.Config, - _ <-chan handshake.TransportParameters, + _ *tls.Config, + _ *handshake.TransportParameters, + _ protocol.VersionNumber, /* initial version */ _ protocol.PacketNumber, _ utils.Logger, _ protocol.VersionNumber, @@ -644,8 +645,9 @@ var _ = Describe("Client", func() { _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, - _ *mint.Config, - _ <-chan handshake.TransportParameters, + _ *tls.Config, + _ *handshake.TransportParameters, + _ protocol.VersionNumber, /* initial version */ _ protocol.PacketNumber, _ utils.Logger, _ protocol.VersionNumber, @@ -861,25 +863,6 @@ var _ = Describe("Client", func() { Expect(cl.GetVersion()).To(Equal(cl.version)) }) - It("ignores packets with the wrong Long Header Type", func() { - cl.config = &Config{} - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - PayloadLen: 123, - SrcConnectionID: connID, - DestConnectionID: connID, - PacketNumberLen: protocol.PacketNumberLen1, - Version: versionIETFFrames, - } - err := cl.handlePacketImpl(&receivedPacket{ - remoteAddr: addr, - header: hdr, - data: make([]byte, 456), - }) - Expect(err).To(MatchError("Received unsupported packet type: Initial")) - }) - It("ignores packets without connection id, if it didn't request connection id trunctation", func() { cl.version = versionGQUICFrames cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls diff --git a/crypto_stream.go b/crypto_stream.go index d51dc2ab..87933357 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -1,42 +1,65 @@ package quic import ( + "fmt" "io" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) type cryptoStream interface { - StreamID() protocol.StreamID - io.Reader + // for receiving data + HandleCryptoFrame(*wire.CryptoFrame) error + GetCryptoData() []byte + // for sending data io.Writer - handleStreamFrame(*wire.StreamFrame) error - hasData() bool - popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool) - closeForShutdown(error) - setReadOffset(protocol.ByteCount) - // methods needed for flow control - getWindowUpdate() protocol.ByteCount - handleMaxStreamDataFrame(*wire.MaxStreamDataFrame) + HasData() bool + PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame } type cryptoStreamImpl struct { - *stream + queue *frameSorter + + writeOffset protocol.ByteCount + writeBuf []byte } -var _ cryptoStream = &cryptoStreamImpl{} - -func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream { - str := newStream(version.CryptoStreamID(), sender, flowController, version) - return &cryptoStreamImpl{str} +func newCryptoStream() cryptoStream { + return &cryptoStreamImpl{ + queue: newFrameSorter(), + } } -// SetReadOffset sets the read offset. -// It is only needed for the crypto stream. -// It must not be called concurrently with any other stream methods, especially Read and Write. -func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) { - s.receiveStream.readOffset = offset - s.receiveStream.frameQueue.readPos = offset +func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { + if maxOffset := f.Offset + protocol.ByteCount(len(f.Data)); maxOffset > protocol.MaxCryptoStreamOffset { + return fmt.Errorf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset) + } + return s.queue.Push(f.Data, f.Offset, false) +} + +// GetCryptoData retrieves data that was received in CRYPTO frames +func (s *cryptoStreamImpl) GetCryptoData() []byte { + data, _ := s.queue.Pop() + return data +} + +// Writes writes data that should be sent out in CRYPTO frames +func (s *cryptoStreamImpl) Write(p []byte) (int, error) { + s.writeBuf = append(s.writeBuf, p...) + return len(p), nil +} + +func (s *cryptoStreamImpl) HasData() bool { + return len(s.writeBuf) > 0 +} + +func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { + f := &wire.CryptoFrame{Offset: s.writeOffset} + n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) + f.Data = s.writeBuf[:n] + s.writeBuf = s.writeBuf[n:] + s.writeOffset += n + return f } diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go new file mode 100644 index 00000000..747b280a --- /dev/null +++ b/crypto_stream_manager.go @@ -0,0 +1,55 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type cryptoDataHandler interface { + HandleData([]byte, protocol.EncryptionLevel) error +} + +type cryptoStreamManager struct { + cryptoHandler cryptoDataHandler + + initialStream cryptoStream + handshakeStream cryptoStream +} + +func newCryptoStreamManager( + cryptoHandler cryptoDataHandler, + initialStream cryptoStream, + handshakeStream cryptoStream, +) *cryptoStreamManager { + return &cryptoStreamManager{ + cryptoHandler: cryptoHandler, + initialStream: initialStream, + handshakeStream: handshakeStream, + } +} + +func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { + var str cryptoStream + switch encLevel { + case protocol.EncryptionInitial: + str = m.initialStream + case protocol.EncryptionHandshake: + str = m.handshakeStream + default: + return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) + } + if err := str.HandleCryptoFrame(frame); err != nil { + return err + } + for { + data := str.GetCryptoData() + if data == nil { + return nil + } + if err := m.cryptoHandler.HandleData(data, encLevel); err != nil { + return err + } + } +} diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go new file mode 100644 index 00000000..bc281505 --- /dev/null +++ b/crypto_stream_manager_test.go @@ -0,0 +1,65 @@ +package quic + +import ( + "errors" + + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Crypto Stream Manager", func() { + var ( + csm *cryptoStreamManager + cs *MockCryptoDataHandler + ) + + BeforeEach(func() { + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + cs = NewMockCryptoDataHandler(mockCtrl) + csm = newCryptoStreamManager(cs, initialStream, handshakeStream) + }) + + It("handles in in-order crypto frame", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionInitial) + Expect(csm.HandleCryptoFrame(f, protocol.EncryptionInitial)).To(Succeed()) + }) + + It("errors for unknown encryption levels", func() { + err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption1RTT) + Expect(err).To(MatchError("received CRYPTO frame with unexpected encryption level: 1-RTT")) + }) + + It("handles out-of-order crypto frames", func() { + f1 := &wire.CryptoFrame{Data: []byte("foo")} + f2 := &wire.CryptoFrame{ + Offset: 3, + Data: []byte("bar"), + } + gomock.InOrder( + cs.EXPECT().HandleData([]byte("foo"), protocol.EncryptionInitial), + cs.EXPECT().HandleData([]byte("bar"), protocol.EncryptionInitial), + ) + Expect(csm.HandleCryptoFrame(f1, protocol.EncryptionInitial)).To(Succeed()) + Expect(csm.HandleCryptoFrame(f2, protocol.EncryptionInitial)).To(Succeed()) + }) + + It("handles handshake data", func() { + f := &wire.CryptoFrame{Data: []byte("foobar")} + cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake) + Expect(csm.HandleCryptoFrame(f, protocol.EncryptionHandshake)).To(Succeed()) + }) + + It("returns the error if handling crypto data fails", func() { + testErr := errors.New("test error") + f := &wire.CryptoFrame{Data: []byte("foobar")} + cs.EXPECT().HandleData([]byte("foobar"), protocol.EncryptionHandshake).Return(testErr) + err := csm.HandleCryptoFrame(f, protocol.EncryptionHandshake) + Expect(err).To(MatchError(testErr)) + }) +}) diff --git a/crypto_stream_test.go b/crypto_stream_test.go index ce97e94a..514930db 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,7 +1,10 @@ package quic import ( + "fmt" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -9,18 +12,88 @@ import ( var _ = Describe("Crypto Stream", func() { var ( - str *cryptoStreamImpl - mockSender *MockStreamSender + str cryptoStream ) BeforeEach(func() { - mockSender = NewMockStreamSender(mockCtrl) - str = newCryptoStream(mockSender, nil, protocol.VersionWhatever).(*cryptoStreamImpl) + str = newCryptoStream() }) - It("sets the read offset", func() { - str.setReadOffset(0x42) - Expect(str.receiveStream.readOffset).To(Equal(protocol.ByteCount(0x42))) - Expect(str.receiveStream.frameQueue.readPos).To(Equal(protocol.ByteCount(0x42))) + Context("handling incoming data", func() { + It("handles in-order CRYPTO frames", func() { + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: []byte("foobar"), + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal([]byte("foobar"))) + Expect(str.GetCryptoData()).To(BeNil()) + }) + + It("errors if the frame exceeds the maximum offset", func() { + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: protocol.MaxCryptoStreamOffset - 5, + Data: []byte("foobar"), + }) + Expect(err).To(MatchError(fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", protocol.MaxCryptoStreamOffset+1, protocol.MaxCryptoStreamOffset))) + }) + + It("handles out-of-order CRYPTO frames", func() { + err := str.HandleCryptoFrame(&wire.CryptoFrame{ + Offset: 3, + Data: []byte("bar"), + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(BeNil()) + err = str.HandleCryptoFrame(&wire.CryptoFrame{ + Data: []byte("foo"), + }) + Expect(err).ToNot(HaveOccurred()) + Expect(str.GetCryptoData()).To(Equal([]byte("foo"))) + Expect(str.GetCryptoData()).To(Equal([]byte("bar"))) + Expect(str.GetCryptoData()).To(BeNil()) + }) + }) + + Context("writing data", func() { + It("says if it has data", func() { + Expect(str.HasData()).To(BeFalse()) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.HasData()).To(BeTrue()) + }) + + It("pops crypto frames", func() { + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + f := str.PopCryptoFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(BeZero()) + Expect(f.Data).To(Equal([]byte("foobar"))) + }) + + It("coalesces multiple writes", func() { + _, err := str.Write([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("bar")) + Expect(err).ToNot(HaveOccurred()) + f := str.PopCryptoFrame(1000) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(BeZero()) + Expect(f.Data).To(Equal([]byte("foobar"))) + }) + + It("respects the maximum size", func() { + frameHeaderLen := (&wire.CryptoFrame{}).Length(protocol.VersionWhatever) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + f := str.PopCryptoFrame(frameHeaderLen + 3) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(BeZero()) + Expect(f.Data).To(Equal([]byte("foo"))) + f = str.PopCryptoFrame(frameHeaderLen + 3) + Expect(f).ToNot(BeNil()) + Expect(f.Offset).To(Equal(protocol.ByteCount(3))) + Expect(f.Data).To(Equal([]byte("bar"))) + }) }) }) diff --git a/framer.go b/framer.go index 74ca8c45..fbfe9bb7 100644 --- a/framer.go +++ b/framer.go @@ -7,39 +7,47 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type framer struct { +type framer interface { + QueueControlFrame(wire.Frame) + AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) + + AddActiveStream(protocol.StreamID) + AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame +} + +type framerI struct { + mutex sync.Mutex + streamGetter streamGetter - cryptoStream cryptoStream version protocol.VersionNumber - streamQueueMutex sync.Mutex - activeStreams map[protocol.StreamID]struct{} - streamQueue []protocol.StreamID + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID controlFrameMutex sync.Mutex controlFrames []wire.Frame } +var _ framer = &framerI{} + func newFramer( - cryptoStream cryptoStream, streamGetter streamGetter, v protocol.VersionNumber, -) *framer { - return &framer{ +) framer { + return &framerI{ streamGetter: streamGetter, - cryptoStream: cryptoStream, activeStreams: make(map[protocol.StreamID]struct{}), version: v, } } -func (f *framer) QueueControlFrame(frame wire.Frame) { +func (f *framerI) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Lock() f.controlFrames = append(f.controlFrames, frame) f.controlFrameMutex.Unlock() } -func (f *framer) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { +func (f *framerI) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { var length protocol.ByteCount f.controlFrameMutex.Lock() for len(f.controlFrames) > 0 { @@ -56,20 +64,18 @@ func (f *framer) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCo return frames, length } -// AddActiveStream adds a stream that has data to write. -// It should not be used for the crypto stream. -func (f *framer) AddActiveStream(id protocol.StreamID) { - f.streamQueueMutex.Lock() +func (f *framerI) AddActiveStream(id protocol.StreamID) { + f.mutex.Lock() if _, ok := f.activeStreams[id]; !ok { f.streamQueue = append(f.streamQueue, id) f.activeStreams[id] = struct{}{} } - f.streamQueueMutex.Unlock() + f.mutex.Unlock() } -func (f *framer) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { +func (f *framerI) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { var length protocol.ByteCount - f.streamQueueMutex.Lock() + f.mutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet numActiveStreams := len(f.streamQueue) for i := 0; i < numActiveStreams; i++ { @@ -98,6 +104,6 @@ func (f *framer) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCou frames = append(frames, frame) length += frame.Length(f.version) } - f.streamQueueMutex.Unlock() + f.mutex.Unlock() return frames } diff --git a/framer_test.go b/framer_test.go index a0cdcb31..c8c33a5a 100644 --- a/framer_test.go +++ b/framer_test.go @@ -18,10 +18,10 @@ var _ = Describe("Stream Framer", func() { ) var ( - framer *framer - cryptoStream *MockCryptoStream + framer framer stream1, stream2 *MockSendStreamI streamGetter *MockStreamGetter + version protocol.VersionNumber ) BeforeEach(func() { @@ -30,8 +30,7 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().StreamID().Return(protocol.StreamID(5)).AnyTimes() stream2 = NewMockSendStreamI(mockCtrl) stream2.EXPECT().StreamID().Return(protocol.StreamID(6)).AnyTimes() - cryptoStream = NewMockCryptoStream(mockCtrl) - framer = newFramer(cryptoStream, streamGetter, versionGQUICFrames) + framer = newFramer(streamGetter, version) }) Context("handling control frames", func() { @@ -43,7 +42,7 @@ var _ = Describe("Stream Framer", func() { frames, length := framer.AppendControlFrames(nil, 1000) Expect(frames).To(ContainElement(mdf)) Expect(frames).To(ContainElement(msdf)) - Expect(length).To(Equal(mdf.Length(framer.version) + msdf.Length(framer.version))) + Expect(length).To(Equal(mdf.Length(version) + msdf.Length(version))) }) It("appends to the slice given", func() { @@ -52,13 +51,13 @@ var _ = Describe("Stream Framer", func() { framer.QueueControlFrame(mdf) frames, length := framer.AppendControlFrames([]wire.Frame{ack}, 1000) Expect(frames).To(Equal([]wire.Frame{ack, mdf})) - Expect(length).To(Equal(mdf.Length(framer.version))) + Expect(length).To(Equal(mdf.Length(version))) }) It("adds the right number of frames", func() { maxSize := protocol.ByteCount(1000) bf := &wire.BlockedFrame{Offset: 0x1337} - bfLen := bf.Length(framer.version) + bfLen := bf.Length(version) numFrames := int(maxSize / bfLen) // max number of frames that fit into maxSize for i := 0; i < numFrames+1; i++ { framer.QueueControlFrame(bf) diff --git a/h2quic/client.go b/h2quic/client.go index ac28a7f0..dddc0d23 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -94,7 +94,7 @@ func (c *client) dial() error { } // once the version has been negotiated, open the header stream - c.headerStream, err = c.session.OpenStream() + c.headerStream, err = c.session.OpenStreamSync() if err != nil { return err } diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 58360437..7ff35c29 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -110,13 +110,15 @@ func (h *sentPacketHandler) SetHandshakeComplete() { h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.") var queue []*Packet for _, packet := range h.retransmissionQueue { - if packet.EncryptionLevel == protocol.EncryptionForwardSecure { + if packet.EncryptionLevel == protocol.EncryptionForwardSecure || + packet.EncryptionLevel == protocol.Encryption1RTT { queue = append(queue, packet) } } var handshakePackets []*Packet h.packetHistory.Iterate(func(p *Packet) (bool, error) { - if p.EncryptionLevel != protocol.EncryptionForwardSecure { + if p.EncryptionLevel != protocol.EncryptionForwardSecure && + p.EncryptionLevel != protocol.Encryption1RTT { handshakePackets = append(handshakePackets, p) } return true, nil @@ -167,7 +169,8 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmitt isRetransmittable := len(packet.Frames) != 0 if isRetransmittable { - if packet.EncryptionLevel < protocol.EncryptionForwardSecure { + if packet.EncryptionLevel != protocol.EncryptionForwardSecure && + packet.EncryptionLevel != protocol.Encryption1RTT { h.lastSentHandshakePacketTime = packet.SendTime } h.lastSentRetransmittablePacketTime = packet.SendTime @@ -214,8 +217,11 @@ func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumbe priorInFlight := h.bytesInFlight for _, p := range ackedPackets { - if encLevel < p.EncryptionLevel { - return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel) + // TODO(#1534): also check the encryption level for IETF QUIC + if !h.version.UsesTLS() { + if encLevel < p.EncryptionLevel { + return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel) + } } // largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0 // It is safe to ignore the corner case of packets that just acked packet 0, because @@ -586,7 +592,9 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error { var handshakePackets []*Packet h.packetHistory.Iterate(func(p *Packet) (bool, error) { - if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure { + if p.canBeRetransmitted && + p.EncryptionLevel != protocol.EncryptionForwardSecure && + p.EncryptionLevel != protocol.Encryption1RTT { handshakePackets = append(handshakePackets, p) } return true, nil diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index 9960e537..32a1def7 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -1051,7 +1051,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.GetAlarmTimeout().Sub(lastHandshakePacketSendTime)).To(Equal(4 * time.Minute)) }) - It("rejects an ACK that acks packets with a higher encryption level", func() { + // TODO(#1534): also check the encryption level for IETF QUIC + PIt("rejects an ACK that acks packets with a higher encryption level", func() { handler.SentPacket(&Packet{ PacketNumber: 13, EncryptionLevel: protocol.EncryptionForwardSecure, diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index 91aa2697..cf372ef5 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -35,7 +35,8 @@ func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement { } if p.canBeRetransmitted { h.numOutstandingPackets++ - if p.EncryptionLevel < protocol.EncryptionForwardSecure { + if p.EncryptionLevel != protocol.EncryptionForwardSecure && + p.EncryptionLevel != protocol.Encryption1RTT { h.numOutstandingHandshakePackets++ } } @@ -106,7 +107,8 @@ func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) if h.numOutstandingPackets < 0 { panic("numOutstandingHandshakePackets negative") } - if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { + if el.Value.EncryptionLevel != protocol.EncryptionForwardSecure && + el.Value.EncryptionLevel != protocol.Encryption1RTT { h.numOutstandingHandshakePackets-- if h.numOutstandingHandshakePackets < 0 { panic("numOutstandingHandshakePackets negative") @@ -147,7 +149,8 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { if h.numOutstandingPackets < 0 { panic("numOutstandingHandshakePackets negative") } - if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure { + if el.Value.EncryptionLevel != protocol.EncryptionForwardSecure && + el.Value.EncryptionLevel != protocol.Encryption1RTT { h.numOutstandingHandshakePackets-- if h.numOutstandingHandshakePackets < 0 { panic("numOutstandingHandshakePackets negative") diff --git a/internal/crypto/hkdf.go b/internal/crypto/hkdf.go index 2501aa79..06228938 100644 --- a/internal/crypto/hkdf.go +++ b/internal/crypto/hkdf.go @@ -48,7 +48,7 @@ func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte { } // hkdfExpandLabel HKDF expands a label -func hkdfExpandLabel(hash crypto.Hash, secret []byte, label string, length int) []byte { +func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, length int) []byte { const prefix = "quic " qlabel := make([]byte, 2 /* length */ +1 /* length of label */ +len(prefix)+len(label)+1 /* length of context (empty) */) binary.BigEndian.PutUint16(qlabel[0:2], uint16(length)) diff --git a/internal/crypto/key_derivation.go b/internal/crypto/key_derivation.go deleted file mode 100644 index d635b12b..00000000 --- a/internal/crypto/key_derivation.go +++ /dev/null @@ -1,51 +0,0 @@ -package crypto - -import ( - "crypto" - - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -const ( - clientExporterLabel = "EXPORTER-QUIC client 1rtt" - serverExporterLabel = "EXPORTER-QUIC server 1rtt" -) - -// A TLSExporter gets the negotiated ciphersuite and computes exporter -type TLSExporter interface { - ConnectionState() mint.ConnectionState - ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) -} - -// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance -func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) { - var myLabel, otherLabel string - if pers == protocol.PerspectiveClient { - myLabel = clientExporterLabel - otherLabel = serverExporterLabel - } else { - myLabel = serverExporterLabel - otherLabel = clientExporterLabel - } - myKey, myIV, err := computeKeyAndIV(tls, myLabel) - if err != nil { - return nil, err - } - otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel) - if err != nil { - return nil, err - } - return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) -} - -func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) { - cs := tls.ConnectionState().CipherSuite - secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size()) - if err != nil { - return nil, nil, err - } - key = hkdfExpand(crypto.SHA256, secret, []byte("key"), cs.KeyLen) - iv = hkdfExpand(crypto.SHA256, secret, []byte("iv"), cs.IvLen) - return key, iv, nil -} diff --git a/internal/crypto/key_derivation_test.go b/internal/crypto/key_derivation_test.go deleted file mode 100644 index 5b530ff8..00000000 --- a/internal/crypto/key_derivation_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package crypto - -import ( - "crypto" - "errors" - - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/protocol" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -type mockTLSExporter struct { - hash crypto.Hash - computerError error -} - -var _ TLSExporter = &mockTLSExporter{} - -func (c *mockTLSExporter) Handshake() mint.Alert { panic("not implemented") } - -func (c *mockTLSExporter) ConnectionState() mint.ConnectionState { - return mint.ConnectionState{ - CipherSuite: mint.CipherSuiteParams{ - Hash: c.hash, - KeyLen: 32, - IvLen: 12, - }, - } -} - -func (c *mockTLSExporter) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - if c.computerError != nil { - return nil, c.computerError - } - return append([]byte(label), context...), nil -} - -var _ = Describe("Key Derivation", func() { - It("derives keys", func() { - clientAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - serverAEAD, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256}, protocol.PerspectiveServer) - Expect(err).ToNot(HaveOccurred()) - ciphertext := clientAEAD.Seal(nil, []byte("foobar"), 0, []byte("aad")) - data, err := serverAEAD.Open(nil, ciphertext, 0, []byte("aad")) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) - }) - - It("fails when computing the exporter fails", func() { - testErr := errors.New("test error") - _, err := DeriveAESKeys(&mockTLSExporter{hash: crypto.SHA256, computerError: testErr}, protocol.PerspectiveClient) - Expect(err).To(MatchError(testErr)) - }) -}) diff --git a/internal/crypto/null_aead_aesgcm.go b/internal/crypto/null_aead_aesgcm.go index 48dce27e..17dad016 100644 --- a/internal/crypto/null_aead_aesgcm.go +++ b/internal/crypto/null_aead_aesgcm.go @@ -28,13 +28,13 @@ func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspec func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) { initialSecret := hkdfExtract(crypto.SHA256, connID, quicVersion1Salt) - clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, "client in", crypto.SHA256.Size()) - serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, "server in", crypto.SHA256.Size()) + clientSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "client in", crypto.SHA256.Size()) + serverSecret = HkdfExpandLabel(crypto.SHA256, initialSecret, "server in", crypto.SHA256.Size()) return } func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) { - key = hkdfExpandLabel(crypto.SHA256, secret, "key", 16) - iv = hkdfExpandLabel(crypto.SHA256, secret, "iv", 12) + key = HkdfExpandLabel(crypto.SHA256, secret, "key", 16) + iv = HkdfExpandLabel(crypto.SHA256, secret, "iv", 12) return } diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go new file mode 100644 index 00000000..21d61a8f --- /dev/null +++ b/internal/handshake/aead.go @@ -0,0 +1,58 @@ +package handshake + +import ( + "crypto/cipher" + "encoding/binary" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type sealer struct { + iv []byte + aead cipher.AEAD + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ Sealer = &sealer{} + +func newSealer(aead cipher.AEAD, iv []byte) Sealer { + return &sealer{ + iv: iv, + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) + return s.aead.Seal(dst, s.nonceBuf, src, ad) +} + +func (s *sealer) Overhead() int { + return s.aead.Overhead() +} + +type opener struct { + iv []byte + aead cipher.AEAD + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ Opener = &opener{} + +func newOpener(aead cipher.AEAD, iv []byte) Opener { + return &opener{ + iv: iv, + aead: aead, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) + return o.aead.Open(dst, o.nonceBuf, src, ad) +} diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go new file mode 100644 index 00000000..3556630c --- /dev/null +++ b/internal/handshake/aead_test.go @@ -0,0 +1,51 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("AEAD", func() { + var sealer Sealer + var opener Opener + + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + BeforeEach(func() { + key := make([]byte, 16) + rand.Read(key) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + + iv := make([]byte, 12) + rand.Read(iv) + sealer = newSealer(aead, iv) + opener = newOpener(aead, iv) + }) + + It("encrypts and decrypts a message", func() { + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) + + It("fails to open a message if the packet number is not the same", func() { + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x42, ad) + Expect(err).To(MatchError("cipher: message authentication failed")) + }) +}) diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index 3443161b..8edf5b7c 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -108,7 +108,7 @@ func NewCryptoSetupClient( return cs, nil } -func (h *cryptoSetupClient) HandleCryptoStream() error { +func (h *cryptoSetupClient) RunHandshake() error { messageChan := make(chan HandshakeMessage) errorChan := make(chan error, 1) diff --git a/internal/handshake/crypto_setup_client_test.go b/internal/handshake/crypto_setup_client_test.go index aceb3286..d8f51dde 100644 --- a/internal/handshake/crypto_setup_client_test.go +++ b/internal/handshake/crypto_setup_client_test.go @@ -152,13 +152,13 @@ var _ = Describe("Client Crypto Setup", func() { It("rejects handshake messages with the wrong message tag", func() { HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) }) It("errors on invalid handshake messages", func() { stream.dataToRead.Write([]byte("invalid message")) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeFailed)) }) @@ -170,7 +170,7 @@ var _ = Describe("Client Crypto Setup", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() @@ -453,7 +453,7 @@ var _ = Describe("Client Crypto Setup", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() @@ -470,7 +470,7 @@ var _ = Describe("Client Crypto Setup", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() @@ -703,7 +703,7 @@ var _ = Describe("Client Crypto Setup", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() @@ -722,7 +722,7 @@ var _ = Describe("Client Crypto Setup", func() { cs.serverVerified = true go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() @@ -922,7 +922,7 @@ var _ = Describe("Client Crypto Setup", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() @@ -938,7 +938,7 @@ var _ = Describe("Client Crypto Setup", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() @@ -955,7 +955,7 @@ var _ = Describe("Client Crypto Setup", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.HandshakeFailed, errMockStreamClosing.Error()))) close(done) }() diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index b7a46480..f91bb4f8 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -107,7 +107,7 @@ func NewCryptoSetup( } // HandleCryptoStream reads and writes messages on the crypto stream -func (h *cryptoSetupServer) HandleCryptoStream() error { +func (h *cryptoSetupServer) RunHandshake() error { for { var chloData bytes.Buffer message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) diff --git a/internal/handshake/crypto_setup_server_test.go b/internal/handshake/crypto_setup_server_test.go index 0e7133be..0b82e0ac 100644 --- a/internal/handshake/crypto_setup_server_test.go +++ b/internal/handshake/crypto_setup_server_test.go @@ -218,7 +218,7 @@ var _ = Describe("Server Crypto Setup", func() { TagNSTP: []byte("foobar"), }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(ErrNSTPExperiment)) }) @@ -315,7 +315,7 @@ var _ = Describe("Server Crypto Setup", func() { }, }.Write(&stream.dataToRead) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) Expect(handshakeEvent).To(Receive()) // for the switch to secure @@ -327,14 +327,14 @@ var _ = Describe("Server Crypto Setup", func() { It("rejects client nonces that have the wrong length", func() { fullCHLO[TagNONC] = []byte("too short client nonce") HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length"))) }) It("rejects client nonces that have the wrong OBIT value", func() { fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0 HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching"))) }) @@ -342,13 +342,13 @@ var _ = Describe("Server Crypto Setup", func() { testErr := errors.New("test error") kex.sharedKeyError = testErr HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(testErr)) }) It("handles 0-RTT handshake", func() { HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) @@ -400,14 +400,14 @@ var _ = Describe("Server Crypto Setup", func() { TagSNI: []byte("quic.clemente.io"), }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag"))) }) It("rejects CHLOs with a version tag that has the wrong length", func() { fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag"))) }) @@ -420,7 +420,7 @@ var _ = Describe("Server Crypto Setup", func() { binary.BigEndian.PutUint32(b, uint32(lowestSupportedVersion)) fullCHLO[TagVER] = b HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected"))) }) @@ -433,35 +433,35 @@ var _ = Describe("Server Crypto Setup", func() { binary.BigEndian.PutUint32(b, uint32(unsupportedVersion)) fullCHLO[TagVER] = b HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).ToNot(HaveOccurred()) }) It("errors if the AEAD tag is missing", func() { delete(fullCHLO, TagAEAD) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the AEAD tag has the wrong value", func() { fullCHLO[TagAEAD] = []byte("wrong") HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the KEXS tag is missing", func() { delete(fullCHLO, TagKEXS) HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the KEXS tag has the wrong value", func() { fullCHLO[TagKEXS] = []byte("wrong") HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) }) @@ -473,7 +473,7 @@ var _ = Describe("Server Crypto Setup", func() { TagSTK: validSTK, }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) }) @@ -485,19 +485,19 @@ var _ = Describe("Server Crypto Setup", func() { TagSNI: nil, }, }.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) }) It("errors with invalid message", func() { stream.dataToRead.Write([]byte("invalid message")) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.HandshakeFailed)) }) It("errors with non-CHLO message", func() { HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead) - err := cs.HandleCryptoStream() + err := cs.RunHandshake() Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) }) diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 389ecbd1..4a375412 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -1,168 +1,474 @@ package handshake import ( + "bytes" + "crypto/tls" "errors" "fmt" "io" - "sync" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/marten-seemann/qtls" ) -// KeyDerivationFunction is used for key derivation -type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) +type messageType uint8 + +// TLS handshake message types. +const ( + typeClientHello messageType = 1 + typeServerHello messageType = 2 + typeEncryptedExtensions messageType = 8 + typeCertificate messageType = 11 + typeCertificateRequest messageType = 13 + typeCertificateVerify messageType = 15 + typeFinished messageType = 20 +) + +func (m messageType) String() string { + switch m { + case typeClientHello: + return "ClientHello" + case typeServerHello: + return "ServerHello" + case typeEncryptedExtensions: + return "EncryptedExtensions" + case typeCertificate: + return "Certificate" + case typeCertificateRequest: + return "CertificateRequest" + case typeCertificateVerify: + return "CertificateVerify" + case typeFinished: + return "Finished" + default: + return fmt.Sprintf("unknown message type: %d", m) + } +} type cryptoSetupTLS struct { - mutex sync.RWMutex + tlsConf *qtls.Config + + messageChan chan []byte + + readEncLevel protocol.EncryptionLevel + writeEncLevel protocol.EncryptionLevel + + handleParamsCallback func(*TransportParameters) + handshakeEvent chan<- struct{} + handshakeComplete chan<- struct{} + receivedTransportParams <-chan TransportParameters + + clientHelloWritten bool + clientHelloWrittenChan chan struct{} + + initialReadBuf bytes.Buffer + initialStream io.Writer + initialAEAD crypto.AEAD + + handshakeReadBuf bytes.Buffer + handshakeStream io.Writer + handshakeOpener Opener + handshakeSealer Sealer + + opener Opener + sealer Sealer + // TODO: add a 1-RTT stream (used for session tickets) + + receivedWriteKey chan struct{} + receivedReadKey chan struct{} + + logger utils.Logger perspective protocol.Perspective - - keyDerivation KeyDerivationFunction - nullAEAD crypto.AEAD - aead crypto.AEAD - - tls mintTLS - conn *cryptoStreamConn - handshakeEvent chan<- struct{} - handshakeComplete chan<- struct{} } +var _ qtls.RecordLayer = &cryptoSetupTLS{} var _ CryptoSetupTLS = &cryptoSetupTLS{} -// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server -func NewCryptoSetupTLSServer( - cryptoStream io.ReadWriter, - connID protocol.ConnectionID, - config *mint.Config, - handshakeEvent chan<- struct{}, - handshakeComplete chan<- struct{}, - version protocol.VersionNumber, -) (CryptoSetupTLS, error) { - nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version) - if err != nil { - return nil, err - } - conn := newCryptoStreamConn(cryptoStream) - tls := mint.Server(conn, config) - return &cryptoSetupTLS{ - tls: tls, - conn: conn, - nullAEAD: nullAEAD, - perspective: protocol.PerspectiveServer, - keyDerivation: crypto.DeriveAESKeys, - handshakeEvent: handshakeEvent, - handshakeComplete: handshakeComplete, - }, nil +type versionInfo struct { + initialVersion protocol.VersionNumber + supportedVersions []protocol.VersionNumber + currentVersion protocol.VersionNumber } -// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client +// NewCryptoSetupTLSClient creates a new TLS crypto setup for the client func NewCryptoSetupTLSClient( - cryptoStream io.ReadWriter, + initialStream io.Writer, + handshakeStream io.Writer, connID protocol.ConnectionID, - config *mint.Config, + params *TransportParameters, + handleParams func(*TransportParameters), handshakeEvent chan<- struct{}, handshakeComplete chan<- struct{}, - version protocol.VersionNumber, -) (CryptoSetupTLS, error) { - nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version) - if err != nil { - return nil, err - } - conn := newCryptoStreamConn(cryptoStream) - tls := mint.Client(conn, config) - return &cryptoSetupTLS{ - tls: tls, - conn: conn, - perspective: protocol.PerspectiveClient, - nullAEAD: nullAEAD, - keyDerivation: crypto.DeriveAESKeys, - handshakeEvent: handshakeEvent, - handshakeComplete: handshakeComplete, - }, nil + tlsConf *tls.Config, + initialVersion protocol.VersionNumber, + supportedVersions []protocol.VersionNumber, + currentVersion protocol.VersionNumber, + logger utils.Logger, + perspective protocol.Perspective, +) (CryptoSetupTLS, <-chan struct{} /* ClientHello written */, error) { + return newCryptoSetupTLS( + initialStream, + handshakeStream, + connID, + params, + handleParams, + handshakeEvent, + handshakeComplete, + tlsConf, + versionInfo{ + currentVersion: currentVersion, + initialVersion: initialVersion, + supportedVersions: supportedVersions, + }, + logger, + perspective, + ) } -func (h *cryptoSetupTLS) HandleCryptoStream() error { - for { - if alert := h.tls.Handshake(); alert != mint.AlertNoAlert { - return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) - } - state := h.tls.ConnectionState().HandshakeState - if err := h.conn.Flush(); err != nil { - return err - } - if state == mint.StateClientConnected || state == mint.StateServerConnected { - break - } - } +// NewCryptoSetupTLSServer creates a new TLS crypto setup for the server +func NewCryptoSetupTLSServer( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + params *TransportParameters, + handleParams func(*TransportParameters), + handshakeEvent chan<- struct{}, + handshakeComplete chan<- struct{}, + tlsConf *tls.Config, + supportedVersions []protocol.VersionNumber, + currentVersion protocol.VersionNumber, + logger utils.Logger, + perspective protocol.Perspective, +) (CryptoSetupTLS, error) { + cs, _, err := newCryptoSetupTLS( + initialStream, + handshakeStream, + connID, + params, + handleParams, + handshakeEvent, + handshakeComplete, + tlsConf, + versionInfo{ + currentVersion: currentVersion, + supportedVersions: supportedVersions, + }, + logger, + perspective, + ) + return cs, err +} - aead, err := h.keyDerivation(h.tls, h.perspective) +func newCryptoSetupTLS( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + params *TransportParameters, + handleParams func(*TransportParameters), + handshakeEvent chan<- struct{}, + handshakeComplete chan<- struct{}, + tlsConf *tls.Config, + versionInfo versionInfo, + logger utils.Logger, + perspective protocol.Perspective, +) (CryptoSetupTLS, <-chan struct{} /* ClientHello written */, error) { + initialAEAD, err := crypto.NewNullAEAD(perspective, connID, protocol.VersionTLS) if err != nil { + return nil, nil, err + } + cs := &cryptoSetupTLS{ + initialStream: initialStream, + initialAEAD: initialAEAD, + handshakeStream: handshakeStream, + readEncLevel: protocol.EncryptionInitial, + writeEncLevel: protocol.EncryptionInitial, + handleParamsCallback: handleParams, + handshakeEvent: handshakeEvent, + handshakeComplete: handshakeComplete, + logger: logger, + perspective: perspective, + clientHelloWrittenChan: make(chan struct{}), + messageChan: make(chan []byte, 100), + receivedReadKey: make(chan struct{}), + receivedWriteKey: make(chan struct{}), + } + var extHandler tlsExtensionHandler + switch perspective { + case protocol.PerspectiveClient: + extHandler, cs.receivedTransportParams = newExtensionHandlerClient( + params, + versionInfo.initialVersion, + versionInfo.supportedVersions, + versionInfo.currentVersion, + logger, + ) + case protocol.PerspectiveServer: + extHandler, cs.receivedTransportParams = newExtensionHandlerServer( + params, + versionInfo.supportedVersions, + versionInfo.currentVersion, + logger, + ) + } + qtlsConf := tlsConfigToQtlsConfig(tlsConf) + qtlsConf.AlternativeRecordLayer = cs + qtlsConf.GetExtensions = extHandler.GetExtensions + qtlsConf.ReceivedExtensions = extHandler.ReceivedExtensions + cs.tlsConf = qtlsConf + return cs, cs.clientHelloWrittenChan, nil +} + +func (h *cryptoSetupTLS) RunHandshake() error { + var conn *qtls.Conn + switch h.perspective { + case protocol.PerspectiveClient: + conn = qtls.Client(nil, h.tlsConf) + case protocol.PerspectiveServer: + conn = qtls.Server(nil, h.tlsConf) + } + if err := conn.Handshake(); err != nil { + close(h.receivedReadKey) + close(h.receivedWriteKey) return err } - h.mutex.Lock() - h.aead = aead - h.mutex.Unlock() - - h.handshakeEvent <- struct{}{} close(h.handshakeComplete) return nil } -func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - return h.nullAEAD.Open(dst, src, packetNumber, associatedData) +func (h *cryptoSetupTLS) HandleData(data []byte, encLevel protocol.EncryptionLevel) error { + var buf *bytes.Buffer + switch encLevel { + case protocol.EncryptionInitial: + buf = &h.initialReadBuf + case protocol.EncryptionHandshake: + buf = &h.handshakeReadBuf + default: + return fmt.Errorf("received handshake data with unexpected encryption level: %s", encLevel) + } + buf.Write(data) + for buf.Len() >= 4 { + b := buf.Bytes() + // read the TLS message length + length := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) + if buf.Len() < 4+length { // message not yet complete + return nil + } + msg := make([]byte, length+4) + buf.Read(msg) + if err := h.handleMessage(msg, encLevel); err != nil { + return err + } + } + return nil } -func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) { - h.mutex.RLock() - defer h.mutex.RUnlock() - - if h.aead == nil { - return nil, errors.New("no 1-RTT sealer") +// handleMessage handles a TLS handshake message. +// It is called by the crypto streams when a new message is available. +func (h *cryptoSetupTLS) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { + msgType := messageType(data[0]) + h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) + if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { + return err + } + h.messageChan <- data + switch h.perspective { + case protocol.PerspectiveClient: + return h.handleMessageForClient(msgType) + case protocol.PerspectiveServer: + return h.handleMessageForServer(msgType) + default: + panic("") + } +} + +func (h *cryptoSetupTLS) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { + var expected protocol.EncryptionLevel + switch msgType { + case typeClientHello, + typeServerHello: + expected = protocol.EncryptionInitial + case typeEncryptedExtensions, + typeCertificate, + typeCertificateRequest, + typeCertificateVerify, + typeFinished: + expected = protocol.EncryptionHandshake + default: + return fmt.Errorf("unexpected handshake message: %d", msgType) + } + if encLevel != expected { + return fmt.Errorf("expected handshake message %d to have encryption level %s, has %s", msgType, expected, encLevel) + } + return nil +} + +func (h *cryptoSetupTLS) handleMessageForServer(msgType messageType) error { + switch msgType { + case typeClientHello: + params := <-h.receivedTransportParams + h.handleParamsCallback(¶ms) + <-h.receivedWriteKey // get the handshake write key + <-h.receivedWriteKey // get the 1-RTT write key + <-h.receivedReadKey // get the handshake read key + h.handshakeEvent <- struct{}{} + // TODO: check that the initial stream doesn't have any more data + case typeCertificate, typeCertificateVerify: + // nothing to do + case typeFinished: + <-h.receivedReadKey // get the 1-RTT read key + h.handshakeEvent <- struct{}{} + // TODO: check that the handshake stream doesn't have any more data + default: + // TODO: think about what to do with unknown message types + return fmt.Errorf("Received unknown handshake message: %d", msgType) + } + return nil +} + +func (h *cryptoSetupTLS) handleMessageForClient(msgType messageType) error { + switch msgType { + case typeServerHello: + <-h.receivedReadKey // get the handshake read key + h.handshakeEvent <- struct{}{} + case typeEncryptedExtensions: + params := <-h.receivedTransportParams + h.handleParamsCallback(¶ms) + case typeCertificateRequest, typeCertificate, typeCertificateVerify: + // nothing to do + case typeFinished: + <-h.receivedWriteKey // get the handshake write key + // TODO: check that the initial stream doesn't have any more data + // While the order of these two is not defined by the TLS spec, + // we have to do it on the same order as our TLS library does it. + <-h.receivedWriteKey // get the handshake write key + <-h.receivedReadKey // get the 1-RTT read key + // TODO: check that the handshake stream doesn't have any more data + h.handshakeEvent <- struct{}{} + default: + // TODO: think about what to do with unknown extensions + return fmt.Errorf("Received unknown handshake message: %d", msgType) + } + return nil +} + +// ReadHandshakeMessage is called by TLS. +// It blocks until a new handshake message is available. +func (h *cryptoSetupTLS) ReadHandshakeMessage() ([]byte, error) { + // TODO: add some error handling here (when the session is closed) + return <-h.messageChan, nil +} + +func (h *cryptoSetupTLS) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) { + key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen()) + iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen()) + opener := newOpener(suite.AEAD(key, iv), iv) + + switch h.readEncLevel { + case protocol.EncryptionInitial: + h.readEncLevel = protocol.EncryptionHandshake + h.handshakeOpener = opener + h.logger.Debugf("Installed Handshake Read keys") + case protocol.EncryptionHandshake: + h.readEncLevel = protocol.Encryption1RTT + h.opener = opener + h.logger.Debugf("Installed 1-RTT Read keys") + default: + panic("unexpected read encryption level") + } + h.receivedReadKey <- struct{}{} +} + +func (h *cryptoSetupTLS) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) { + key := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "key", suite.KeyLen()) + iv := crypto.HkdfExpandLabel(suite.Hash(), trafficSecret, "iv", suite.IVLen()) + sealer := newSealer(suite.AEAD(key, iv), iv) + + switch h.writeEncLevel { + case protocol.EncryptionInitial: + h.writeEncLevel = protocol.EncryptionHandshake + h.handshakeSealer = sealer + h.logger.Debugf("Installed Handshake Write keys") + case protocol.EncryptionHandshake: + h.writeEncLevel = protocol.Encryption1RTT + h.sealer = sealer + h.logger.Debugf("Installed 1-RTT Write keys") + default: + panic("unexpected write encryption level") + } + h.receivedWriteKey <- struct{}{} +} + +// WriteRecord is called when TLS writes data +func (h *cryptoSetupTLS) WriteRecord(p []byte) (int, error) { + switch h.writeEncLevel { + case protocol.EncryptionInitial: + // assume that the first WriteRecord call contains the ClientHello + n, err := h.initialStream.Write(p) + if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { + h.clientHelloWritten = true + close(h.clientHelloWrittenChan) + } + return n, err + case protocol.EncryptionHandshake: + return h.handshakeStream.Write(p) + default: + return 0, fmt.Errorf("unexpected write encryption level: %s", h.writeEncLevel) } - return h.aead.Open(dst, src, packetNumber, associatedData) } func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) { - h.mutex.RLock() - defer h.mutex.RUnlock() - - if h.aead != nil { - return protocol.EncryptionForwardSecure, h.aead + if h.sealer != nil { + return protocol.Encryption1RTT, h.sealer } - return protocol.EncryptionUnencrypted, h.nullAEAD + if h.handshakeSealer != nil { + return protocol.EncryptionHandshake, h.handshakeSealer + } + return protocol.EncryptionInitial, h.initialAEAD } -func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { - errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String()) - h.mutex.RLock() - defer h.mutex.RUnlock() +func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) { + errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String()) - switch encLevel { - case protocol.EncryptionUnencrypted: - return h.nullAEAD, nil - case protocol.EncryptionForwardSecure: - if h.aead == nil { + switch level { + case protocol.EncryptionInitial: + return h.initialAEAD, nil + case protocol.EncryptionHandshake: + if h.handshakeSealer == nil { return nil, errNoSealer } - return h.aead, nil + return h.handshakeSealer, nil + case protocol.Encryption1RTT: + if h.sealer == nil { + return nil, errNoSealer + } + return h.sealer, nil default: return nil, errNoSealer } } -func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) { - return protocol.EncryptionUnencrypted, h.nullAEAD +func (h *cryptoSetupTLS) OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + return h.initialAEAD.Open(dst, src, pn, ad) +} + +func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + if h.handshakeOpener == nil { + return nil, errors.New("no handshake opener") + } + return h.handshakeOpener.Open(dst, src, pn, ad) +} + +func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + if h.opener == nil { + return nil, errors.New("no 1-RTT opener") + } + return h.opener.Open(dst, src, pn, ad) } func (h *cryptoSetupTLS) ConnectionState() ConnectionState { - h.mutex.Lock() - defer h.mutex.Unlock() - mintConnState := h.tls.ConnectionState() - return ConnectionState{ - // TODO: set the ServerName, once mint exports it - HandshakeComplete: h.aead != nil, - PeerCertificates: mintConnState.PeerCertificates, - } + // TODO: return the connection state + return ConnectionState{} } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 12b2e9cd..5d98e09e 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -2,194 +2,264 @@ package handshake import ( "bytes" - "errors" - "fmt" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/crypto" - "github.com/lucas-clemente/quic-go/internal/mocks/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/marten-seemann/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -func mockKeyDerivation(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error) { - return mockcrypto.NewMockAEAD(mockCtrl), nil +type chunk struct { + data []byte + encLevel protocol.EncryptionLevel } -var _ = Describe("TLS Crypto Setup", func() { - var ( - cs *cryptoSetupTLS - handshakeEvent chan struct{} - handshakeComplete chan struct{} - ) +type stream struct { + encLevel protocol.EncryptionLevel + chunkChan chan<- chunk +} - BeforeEach(func() { - handshakeEvent = make(chan struct{}, 2) - handshakeComplete = make(chan struct{}) - css, err := NewCryptoSetupTLSServer( - newCryptoStreamConn(bytes.NewBuffer([]byte{})), +func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream { + return &stream{ + chunkChan: chunkChan, + encLevel: encLevel, + } +} + +func (s *stream) Write(b []byte) (int, error) { + data := make([]byte, len(b)) + copy(data, b) + select { + case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: + default: + panic("chunkChan too small") + } + return len(b), nil +} + +var _ = Describe("Crypto Setup TLS", func() { + generateCert := func() tls.Certificate { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{}, + SignatureAlgorithm: x509.SHA256WithRSA, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), // valid for an hour + BasicConstraintsValid: true, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) + Expect(err).ToNot(HaveOccurred()) + return tls.Certificate{ + PrivateKey: priv, + Certificate: [][]byte{certDER}, + } + } + + initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { + chunkChan := make(chan chunk, 100) + initialStream := newStream(chunkChan, protocol.EncryptionInitial) + handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) + return chunkChan, initialStream, handshakeStream + } + + handshake := func( + client CryptoSetupTLS, + cChunkChan <-chan chunk, + server CryptoSetupTLS, + sChunkChan <-chan chunk) (error /* client error */, error /* server error */) { + done := make(chan struct{}) + defer close(done) + go func() { + defer GinkgoRecover() + for { + select { + case c := <-cChunkChan: + err := server.HandleData(c.data, c.encLevel) + Expect(err).ToNot(HaveOccurred()) + case c := <-sChunkChan: + err := client.HandleData(c.data, c.encLevel) + Expect(err).ToNot(HaveOccurred()) + case <-done: // handshake complete + } + } + }() + + serverErrChan := make(chan error) + go func() { + defer GinkgoRecover() + serverErrChan <- server.RunHandshake() + }() + + clientErr := client.RunHandshake() + var serverErr error + Eventually(serverErrChan).Should(Receive(&serverErr)) + return clientErr, serverErr + } + + handshakeWithTLSConf := func(clientConf, serverConf *tls.Config) (error /* client error */, error /* server error */) { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, _, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, protocol.ConnectionID{}, - &mint.Config{}, - handshakeEvent, - handshakeComplete, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + clientConf, protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, ) Expect(err).ToNot(HaveOccurred()) - cs = css.(*cryptoSetupTLS) - cs.nullAEAD = mockcrypto.NewMockAEAD(mockCtrl) - }) - It("errors when the handshake fails", func() { - alert := mint.AlertBadRecordMAC - cs.tls = NewMockMintTLS(mockCtrl) - cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(alert) - err := cs.HandleCryptoStream() - Expect(err).To(MatchError(fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert))) - }) - - It("derives keys", func() { - cs.tls = NewMockMintTLS(mockCtrl) - cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) - cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected}) - cs.keyDerivation = mockKeyDerivation - err := cs.HandleCryptoStream() + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + serverConf, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) Expect(err).ToNot(HaveOccurred()) - Expect(handshakeEvent).To(Receive()) - Expect(handshakeComplete).To(BeClosed()) + + return handshake(client, cChunkChan, server, sChunkChan) + } + + It("handshakes", func() { + clientConf := &tls.Config{ServerName: "quic.clemente.io"} + serverConf := testdata.GetTLSConfig() + clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) }) - It("handshakes until it is connected", func() { - cs.tls = NewMockMintTLS(mockCtrl) - cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert).Times(10) - cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerNegotiated}).Times(9) - cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected}) - cs.keyDerivation = mockKeyDerivation - err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) - Expect(handshakeEvent).To(Receive()) - }) - - Context("reporting the handshake state", func() { - It("reports before the handshake compeletes", func() { - cs.tls = NewMockMintTLS(mockCtrl) - cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{}) - state := cs.ConnectionState() - Expect(state.HandshakeComplete).To(BeFalse()) - Expect(state.PeerCertificates).To(BeNil()) - }) - - It("reports after the handshake completes", func() { - cs.tls = NewMockMintTLS(mockCtrl) - cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) - cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected}).Times(2) - cs.keyDerivation = mockKeyDerivation - err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) - state := cs.ConnectionState() - Expect(state.HandshakeComplete).To(BeTrue()) - Expect(state.PeerCertificates).To(BeNil()) - }) - }) - - Context("escalating crypto", func() { - doHandshake := func() { - cs.tls = NewMockMintTLS(mockCtrl) - cs.tls.(*MockMintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) - cs.tls.(*MockMintTLS).EXPECT().ConnectionState().Return(mint.ConnectionState{HandshakeState: mint.StateServerConnected}) - cs.keyDerivation = mockKeyDerivation - err := cs.HandleCryptoStream() - Expect(err).ToNot(HaveOccurred()) + It("handshakes with client auth", func() { + clientConf := &tls.Config{ + ServerName: "quic.clemente.io", + Certificates: []tls.Certificate{generateCert()}, } + serverConf := testdata.GetTLSConfig() + serverConf.ClientAuth = qtls.RequireAnyClientCert + clientErr, serverErr := handshakeWithTLSConf(clientConf, serverConf) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + }) - Context("null encryption", func() { - It("is used initially", func() { - cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar signed")) - enc, sealer := cs.GetSealer() - Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) - d := sealer.Seal(nil, []byte("foobar"), 5, []byte{}) - Expect(d).To(Equal([]byte("foobar signed"))) - }) + It("signals when it has written the ClientHello", func() { + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + client, chChan, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + &TransportParameters{}, + func(p *TransportParameters) {}, + make(chan struct{}, 100), + make(chan struct{}), + &tls.Config{InsecureSkipVerify: true}, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) - It("is used for opening", func() { - cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("foobar enc"), protocol.PacketNumber(10), []byte{}).Return([]byte("foobar"), nil) - d, err := cs.OpenHandshake(nil, []byte("foobar enc"), 10, []byte{}) - Expect(err).ToNot(HaveOccurred()) - Expect(d).To(Equal([]byte("foobar"))) - }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + client.RunHandshake() + close(done) + }() + var ch chunk + Eventually(cChunkChan).Should(Receive(&ch)) + Eventually(chChan).Should(BeClosed()) + // make sure the whole ClientHello was written + Expect(len(ch.data)).To(BeNumerically(">=", 4)) + Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) + length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) + Expect(len(ch.data) - 4).To(Equal(length)) - It("is used for crypto stream", func() { - cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(20), []byte{}).Return([]byte("foobar signed")) - enc, sealer := cs.GetSealerForCryptoStream() - Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) - d := sealer.Seal(nil, []byte("foobar"), 20, []byte{}) - Expect(d).To(Equal([]byte("foobar signed"))) - }) + // make the go routine return + client.HandleData([]byte{1, 0, 0, 1, 0}, protocol.EncryptionInitial) + Eventually(done).Should(BeClosed()) + }) - It("errors if the has the wrong hash", func() { - cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("foobar enc"), protocol.PacketNumber(10), []byte{}).Return(nil, errors.New("authentication failed")) - _, err := cs.OpenHandshake(nil, []byte("foobar enc"), 10, []byte{}) - Expect(err).To(MatchError("authentication failed")) - }) - }) + It("receives transport parameters", func() { + var cTransportParametersRcvd, sTransportParametersRcvd *TransportParameters + cChunkChan, cInitialStream, cHandshakeStream := initStreams() + cTransportParameters := &TransportParameters{IdleTimeout: 0x42 * time.Second} + client, _, err := NewCryptoSetupTLSClient( + cInitialStream, + cHandshakeStream, + protocol.ConnectionID{}, + cTransportParameters, + func(p *TransportParameters) { sTransportParametersRcvd = p }, + make(chan struct{}, 100), + make(chan struct{}), + &tls.Config{ServerName: "quic.clemente.io"}, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("client"), + protocol.PerspectiveClient, + ) + Expect(err).ToNot(HaveOccurred()) - Context("forward-secure encryption", func() { - It("is used for sealing after the handshake completes", func() { - doHandshake() - cs.aead.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar forward sec")) - enc, sealer := cs.GetSealer() - Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) - d := sealer.Seal(nil, []byte("foobar"), 5, []byte{}) - Expect(d).To(Equal([]byte("foobar forward sec"))) - }) + sChunkChan, sInitialStream, sHandshakeStream := initStreams() + sTransportParameters := &TransportParameters{ + IdleTimeout: 0x1337 * time.Second, + StatelessResetToken: bytes.Repeat([]byte{42}, 16), + } + server, err := NewCryptoSetupTLSServer( + sInitialStream, + sHandshakeStream, + protocol.ConnectionID{}, + sTransportParameters, + func(p *TransportParameters) { cTransportParametersRcvd = p }, + make(chan struct{}, 100), + make(chan struct{}), + testdata.GetTLSConfig(), + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + utils.DefaultLogger.WithPrefix("server"), + protocol.PerspectiveServer, + ) + Expect(err).ToNot(HaveOccurred()) - It("is used for opening", func() { - doHandshake() - cs.aead.(*mockcrypto.MockAEAD).EXPECT().Open(nil, []byte("encrypted"), protocol.PacketNumber(6), []byte{}).Return([]byte("decrypted"), nil) - d, err := cs.Open1RTT(nil, []byte("encrypted"), 6, []byte{}) - Expect(err).ToNot(HaveOccurred()) - Expect(d).To(Equal([]byte("decrypted"))) - }) - }) - - Context("forcing encryption levels", func() { - It("forces null encryption", func() { - doHandshake() - cs.nullAEAD.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar signed")) - sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted) - Expect(err).ToNot(HaveOccurred()) - d := sealer.Seal(nil, []byte("foobar"), 5, []byte{}) - Expect(d).To(Equal([]byte("foobar signed"))) - }) - - It("forces forward-secure encryption", func() { - doHandshake() - cs.aead.(*mockcrypto.MockAEAD).EXPECT().Seal(nil, []byte("foobar"), protocol.PacketNumber(5), []byte{}).Return([]byte("foobar forward sec")) - sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) - Expect(err).ToNot(HaveOccurred()) - d := sealer.Seal(nil, []byte("foobar"), 5, []byte{}) - Expect(d).To(Equal([]byte("foobar forward sec"))) - }) - - It("errors if the forward-secure AEAD is not available", func() { - sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) - Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level forward-secure")) - Expect(sealer).To(BeNil()) - }) - - It("never returns a secure AEAD (they don't exist with TLS)", func() { - doHandshake() - sealer, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) - Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level encrypted (not forward-secure)")) - Expect(sealer).To(BeNil()) - }) - - It("errors if no encryption level is specified", func() { - seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified) - Expect(err).To(MatchError("CryptoSetup: no sealer with encryption level unknown")) - Expect(seal).To(BeNil()) - }) - }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + clientErr, serverErr := handshake(client, cChunkChan, server, sChunkChan) + Expect(clientErr).ToNot(HaveOccurred()) + Expect(serverErr).ToNot(HaveOccurred()) + close(done) + }() + Eventually(done).Should(BeClosed()) + Expect(cTransportParametersRcvd).ToNot(BeNil()) + Expect(cTransportParametersRcvd.IdleTimeout).To(Equal(cTransportParameters.IdleTimeout)) + Expect(sTransportParametersRcvd).ToNot(BeNil()) + Expect(sTransportParametersRcvd.IdleTimeout).To(Equal(sTransportParameters.IdleTimeout)) }) }) diff --git a/internal/handshake/crypto_stream_conn.go b/internal/handshake/crypto_stream_conn.go deleted file mode 100644 index a031f90c..00000000 --- a/internal/handshake/crypto_stream_conn.go +++ /dev/null @@ -1,69 +0,0 @@ -package handshake - -import ( - "bytes" - "io" - "net" - "time" -) - -type cryptoStreamConn struct { - buffer *bytes.Buffer - stream io.ReadWriter -} - -var _ net.Conn = &cryptoStreamConn{} - -func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn { - return &cryptoStreamConn{ - stream: stream, - buffer: &bytes.Buffer{}, - } -} - -func (c *cryptoStreamConn) Read(b []byte) (int, error) { - return c.stream.Read(b) -} - -func (c *cryptoStreamConn) Write(p []byte) (int, error) { - return c.buffer.Write(p) -} - -func (c *cryptoStreamConn) Flush() error { - if c.buffer.Len() == 0 { - return nil - } - _, err := c.stream.Write(c.buffer.Bytes()) - c.buffer.Reset() - return err -} - -// Close is not implemented -func (c *cryptoStreamConn) Close() error { - return nil -} - -// LocalAddr is not implemented -func (c *cryptoStreamConn) LocalAddr() net.Addr { - return nil -} - -// RemoteAddr is not implemented -func (c *cryptoStreamConn) RemoteAddr() net.Addr { - return nil -} - -// SetReadDeadline is not implemented -func (c *cryptoStreamConn) SetReadDeadline(time.Time) error { - return nil -} - -// SetWriteDeadline is not implemented -func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error { - return nil -} - -// SetDeadline is not implemented -func (c *cryptoStreamConn) SetDeadline(time.Time) error { - return nil -} diff --git a/internal/handshake/crypto_stream_conn_test.go b/internal/handshake/crypto_stream_conn_test.go deleted file mode 100644 index 64bb6cbd..00000000 --- a/internal/handshake/crypto_stream_conn_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package handshake - -import ( - "bytes" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Crypto Stream Conn", func() { - var ( - stream *bytes.Buffer - csc *cryptoStreamConn - ) - - BeforeEach(func() { - stream = &bytes.Buffer{} - csc = newCryptoStreamConn(stream) - }) - - It("buffers writes", func() { - _, err := csc.Write([]byte("foo")) - Expect(err).ToNot(HaveOccurred()) - Expect(stream.Len()).To(BeZero()) - _, err = csc.Write([]byte("bar")) - Expect(err).ToNot(HaveOccurred()) - Expect(stream.Len()).To(BeZero()) - - Expect(csc.Flush()).To(Succeed()) - Expect(stream.Bytes()).To(Equal([]byte("foobar"))) - }) - - It("reads from the stream", func() { - stream.Write([]byte("foobar")) - b := make([]byte, 6) - n, err := csc.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(b).To(Equal([]byte("foobar"))) - }) -}) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 5813fdd9..a9d71934 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -3,53 +3,51 @@ package handshake import ( "crypto/x509" - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/marten-seemann/qtls" ) +// Opener opens a packet +type Opener interface { + Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) +} + // Sealer seals a packet type Sealer interface { Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte Overhead() int } -// mintTLS combines some methods needed to interact with mint. -type mintTLS interface { - crypto.TLSExporter - Handshake() mint.Alert -} - -// A TLSExtensionHandler sends and received the QUIC TLS extension. -// It provides the parameters sent by the peer on a channel. -type TLSExtensionHandler interface { - Send(mint.HandshakeType, *mint.ExtensionList) error - Receive(mint.HandshakeType, *mint.ExtensionList) error - GetPeerParams() <-chan TransportParameters +// A tlsExtensionHandler sends and received the QUIC TLS extension. +type tlsExtensionHandler interface { + GetExtensions(msgType uint8) []qtls.Extension + ReceivedExtensions(msgType uint8, exts []qtls.Extension) error } type baseCryptoSetup interface { - HandleCryptoStream() error + RunHandshake() error ConnectionState() ConnectionState GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) - GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) } // CryptoSetup is the crypto setup used by gQUIC type CryptoSetup interface { baseCryptoSetup - Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) + GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) + Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, protocol.EncryptionLevel, error) } // CryptoSetupTLS is the crypto setup used by IETF QUIC type CryptoSetupTLS interface { baseCryptoSetup - OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) - Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) + HandleData([]byte, protocol.EncryptionLevel) error + OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) + OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) + Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) } // ConnectionState records basic details about the QUIC connection. diff --git a/internal/handshake/mock_mint_tls_test.go b/internal/handshake/mock_mint_tls_test.go deleted file mode 100644 index c6e7d50b..00000000 --- a/internal/handshake/mock_mint_tls_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: MintTLS) - -// Package handshake is a generated GoMock package. -package handshake - -import ( - reflect "reflect" - - mint "github.com/bifurcation/mint" - gomock "github.com/golang/mock/gomock" -) - -// MockMintTLS is a mock of MintTLS interface -type MockMintTLS struct { - ctrl *gomock.Controller - recorder *MockMintTLSMockRecorder -} - -// MockMintTLSMockRecorder is the mock recorder for MockMintTLS -type MockMintTLSMockRecorder struct { - mock *MockMintTLS -} - -// NewMockMintTLS creates a new mock instance -func NewMockMintTLS(ctrl *gomock.Controller) *MockMintTLS { - mock := &MockMintTLS{ctrl: ctrl} - mock.recorder = &MockMintTLSMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockMintTLS) EXPECT() *MockMintTLSMockRecorder { - return m.recorder -} - -// ComputeExporter mocks base method -func (m *MockMintTLS) ComputeExporter(arg0 string, arg1 []byte, arg2 int) ([]byte, error) { - ret := m.ctrl.Call(m, "ComputeExporter", arg0, arg1, arg2) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ComputeExporter indicates an expected call of ComputeExporter -func (mr *MockMintTLSMockRecorder) ComputeExporter(arg0, arg1, arg2 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComputeExporter", reflect.TypeOf((*MockMintTLS)(nil).ComputeExporter), arg0, arg1, arg2) -} - -// ConnectionState mocks base method -func (m *MockMintTLS) ConnectionState() mint.ConnectionState { - ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(mint.ConnectionState) - return ret0 -} - -// ConnectionState indicates an expected call of ConnectionState -func (mr *MockMintTLSMockRecorder) ConnectionState() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockMintTLS)(nil).ConnectionState)) -} - -// Handshake mocks base method -func (m *MockMintTLS) Handshake() mint.Alert { - ret := m.ctrl.Call(m, "Handshake") - ret0, _ := ret[0].(mint.Alert) - return ret0 -} - -// Handshake indicates an expected call of Handshake -func (mr *MockMintTLSMockRecorder) Handshake() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handshake", reflect.TypeOf((*MockMintTLS)(nil).Handshake)) -} diff --git a/internal/handshake/mockgen.go b/internal/handshake/mockgen.go deleted file mode 100644 index 86232720..00000000 --- a/internal/handshake/mockgen.go +++ /dev/null @@ -1,3 +0,0 @@ -package handshake - -//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS" diff --git a/internal/handshake/qtls.go b/internal/handshake/qtls.go new file mode 100644 index 00000000..fb2f0bd4 --- /dev/null +++ b/internal/handshake/qtls.go @@ -0,0 +1,48 @@ +package handshake + +import ( + "crypto/tls" + + "github.com/marten-seemann/qtls" +) + +func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config { + if c == nil { + c = &tls.Config{} + } + // QUIC requires TLS 1.3 or newer + if c.MinVersion < qtls.VersionTLS13 { + c.MinVersion = qtls.VersionTLS13 + } + if c.MaxVersion < qtls.VersionTLS13 { + c.MaxVersion = qtls.VersionTLS13 + } + return &qtls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + // TODO: make GetCertificate work + // GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + // TODO: make GetConfigForClient work + // GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + } +} diff --git a/internal/handshake/tls_extension.go b/internal/handshake/tls_extension.go index b3665dfe..54b6d642 100644 --- a/internal/handshake/tls_extension.go +++ b/internal/handshake/tls_extension.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" ) @@ -102,22 +101,3 @@ func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error { } return p.Parameters.unmarshal(data) } - -type tlsExtensionBody struct { - data []byte -} - -var _ mint.ExtensionBody = &tlsExtensionBody{} - -func (e *tlsExtensionBody) Type() mint.ExtensionType { - return quicTLSExtensionType -} - -func (e *tlsExtensionBody) Marshal() ([]byte, error) { - return e.data, nil -} - -func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) { - e.data = data - return len(data), nil -} diff --git a/internal/handshake/tls_extension_handler_client.go b/internal/handshake/tls_extension_handler_client.go index d03021ae..182b256a 100644 --- a/internal/handshake/tls_extension_handler_client.go +++ b/internal/handshake/tls_extension_handler_client.go @@ -2,18 +2,16 @@ package handshake import ( "errors" - "fmt" - "github.com/lucas-clemente/quic-go/qerr" - - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/marten-seemann/qtls" ) type extensionHandlerClient struct { ourParams *TransportParameters - paramsChan chan TransportParameters + paramsChan chan<- TransportParameters initialVersion protocol.VersionNumber supportedVersions []protocol.VersionNumber @@ -22,17 +20,16 @@ type extensionHandlerClient struct { logger utils.Logger } -var _ mint.AppExtensionHandler = &extensionHandlerClient{} -var _ TLSExtensionHandler = &extensionHandlerClient{} +var _ tlsExtensionHandler = &extensionHandlerClient{} -// NewExtensionHandlerClient creates a new extension handler for the client. -func NewExtensionHandlerClient( +// newExtensionHandlerClient creates a new extension handler for the client. +func newExtensionHandlerClient( params *TransportParameters, initialVersion protocol.VersionNumber, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, logger utils.Logger, -) TLSExtensionHandler { +) (tlsExtensionHandler, <-chan TransportParameters) { // The client reads the transport parameters from the Encrypted Extensions message. // The paramsChan is used in the session's run loop's select statement. // We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately. @@ -44,44 +41,43 @@ func NewExtensionHandlerClient( supportedVersions: supportedVersions, version: version, logger: logger, - } + }, paramsChan } -func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { - if hType != mint.HandshakeTypeClientHello { +func (h *extensionHandlerClient) GetExtensions(msgType uint8) []qtls.Extension { + if messageType(msgType) != typeClientHello { return nil } h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) - chtp := &clientHelloTransportParameters{ - InitialVersion: h.initialVersion, - Parameters: *h.ourParams, - } - return el.Add(&tlsExtensionBody{data: chtp.Marshal()}) + return []qtls.Extension{{ + Type: quicTLSExtensionType, + Data: (&clientHelloTransportParameters{ + InitialVersion: h.initialVersion, + Parameters: *h.ourParams, + }).Marshal(), + }} } -func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { - ext := &tlsExtensionBody{} - found, err := el.Find(ext) - if err != nil { - return err - } - - if hType != mint.HandshakeTypeEncryptedExtensions { - if found { - return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) - } +func (h *extensionHandlerClient) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error { + if messageType(msgType) != typeEncryptedExtensions { return nil } - // hType == mint.HandshakeTypeEncryptedExtensions + var found bool + eetp := &encryptedExtensionsTransportParameters{} + for _, ext := range exts { + if ext.Type != quicTLSExtensionType { + continue + } + if err := eetp.Unmarshal(ext.Data); err != nil { + return err + } + found = true + } if !found { return errors.New("EncryptedExtensions message didn't contain a QUIC extension") } - eetp := &encryptedExtensionsTransportParameters{} - if err := eetp.Unmarshal(ext.data); err != nil { - return err - } // check that the negotiated_version is the current version if eetp.NegotiatedVersion != h.version { return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version") @@ -106,7 +102,3 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte h.paramsChan <- eetp.Parameters return nil } - -func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters { - return h.paramsChan -} diff --git a/internal/handshake/tls_extension_handler_client_test.go b/internal/handshake/tls_extension_handler_client_test.go index bbb72f7c..9113498a 100644 --- a/internal/handshake/tls_extension_handler_client_test.go +++ b/internal/handshake/tls_extension_handler_client_test.go @@ -2,12 +2,11 @@ package handshake import ( "bytes" - "fmt" "time" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/marten-seemann/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -15,119 +14,99 @@ import ( var _ = Describe("TLS Extension Handler, for the client", func() { var ( - handler *extensionHandlerClient - el mint.ExtensionList + handler *extensionHandlerClient + paramsChan <-chan TransportParameters ) BeforeEach(func() { - handler = NewExtensionHandlerClient(&TransportParameters{}, protocol.VersionWhatever, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerClient) - el = make(mint.ExtensionList, 0) + var h tlsExtensionHandler + h, paramsChan = newExtensionHandlerClient( + &TransportParameters{}, + protocol.VersionWhatever, + nil, + protocol.VersionWhatever, + utils.DefaultLogger, + ) + handler = h.(*extensionHandlerClient) }) Context("sending", func() { It("only adds TransportParameters for the ClientHello", func() { // test 2 other handshake types - err := handler.Send(mint.HandshakeTypeCertificateRequest, &el) - Expect(err).ToNot(HaveOccurred()) - Expect(el).To(BeEmpty()) - err = handler.Send(mint.HandshakeTypeEndOfEarlyData, &el) - Expect(err).ToNot(HaveOccurred()) - Expect(el).To(BeEmpty()) + exts := handler.GetExtensions(uint8(typeCertificateRequest)) + Expect(exts).To(BeEmpty()) + exts = handler.GetExtensions(uint8(typeEncryptedExtensions)) + Expect(exts).To(BeEmpty()) }) It("adds TransportParameters to the ClientHello", func() { handler.initialVersion = 13 - err := handler.Send(mint.HandshakeTypeClientHello, &el) - Expect(err).ToNot(HaveOccurred()) - Expect(el).To(HaveLen(1)) - ext := &tlsExtensionBody{} - found, err := el.Find(ext) - Expect(err).ToNot(HaveOccurred()) - Expect(found).To(BeTrue()) + exts := handler.GetExtensions(uint8(typeClientHello)) + Expect(exts).To(HaveLen(1)) chtp := &clientHelloTransportParameters{} - err = chtp.Unmarshal(ext.data) - Expect(err).ToNot(HaveOccurred()) + Expect(chtp.Unmarshal(exts[0].Data)).To(Succeed()) Expect(chtp.InitialVersion).To(BeEquivalentTo(13)) }) }) Context("receiving", func() { - var fakeBody *tlsExtensionBody var parameters TransportParameters - addEncryptedExtensionsWithParameters := func(params TransportParameters) { - body := (&encryptedExtensionsTransportParameters{ - Parameters: params, - SupportedVersions: []protocol.VersionNumber{handler.version}, - }).Marshal() - Expect(el.Add(&tlsExtensionBody{data: body})).To(Succeed()) + getEncryptedExtensions := func(params TransportParameters) qtls.Extension { + return qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&encryptedExtensionsTransportParameters{ + Parameters: params, + SupportedVersions: []protocol.VersionNumber{handler.version}, + }).Marshal(), + } } BeforeEach(func() { - fakeBody = &tlsExtensionBody{data: []byte("foobar foobar")} parameters = TransportParameters{ IdleTimeout: 0x1337 * time.Second, StatelessResetToken: bytes.Repeat([]byte{0}, 16), } }) - It("blocks until the transport parameters are read", func() { + It("sends the transport parameters on the channel", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - addEncryptedExtensionsWithParameters(parameters) - err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) - Expect(err).ToNot(HaveOccurred()) - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - Expect(handler.GetPeerParams()).To(Receive()) - Eventually(done).Should(BeClosed()) - }) - - It("accepts the TransportParameters on the EncryptedExtensions message", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - addEncryptedExtensionsWithParameters(parameters) - err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := getEncryptedExtensions(parameters) + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).ToNot(HaveOccurred()) close(done) }() var params TransportParameters - Eventually(handler.GetPeerParams()).Should(Receive(¶ms)) + Consistently(done).ShouldNot(BeClosed()) + Expect(paramsChan).To(Receive(¶ms)) Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second)) Eventually(done).Should(BeClosed()) }) It("errors if the EncryptedExtensions message doesn't contain TransportParameters", func() { - err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), nil) Expect(err).To(MatchError("EncryptedExtensions message didn't contain a QUIC extension")) }) - It("rejects the TransportParameters on a wrong handshake types", func() { - err := el.Add(fakeBody) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeCertificate, &el) - Expect(err).To(MatchError(fmt.Sprintf("Unexpected QUIC extension in handshake message %d", mint.HandshakeTypeCertificate))) - }) - It("ignores messages without TransportParameters, if they are not required", func() { - err := handler.Receive(mint.HandshakeTypeCertificate, &el) - Expect(err).ToNot(HaveOccurred()) + Expect(handler.ReceivedExtensions(uint8(typeCertificateVerify), nil)).To(Succeed()) }) It("errors when it can't parse the TransportParameters", func() { - err := el.Add(fakeBody) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: []byte("invalid extension data"), + } + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).To(HaveOccurred()) // this will be some kind of decoding error }) It("rejects TransportParameters if they don't contain the stateless reset token", func() { parameters.StatelessResetToken = nil - addEncryptedExtensionsWithParameters(parameters) - err := handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := getEncryptedExtensions(parameters) + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).To(MatchError("server didn't sent stateless_reset_token")) }) @@ -136,21 +115,22 @@ var _ = Describe("TLS Extension Handler, for the client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - Eventually(handler.GetPeerParams()).Should(Receive()) + Eventually(paramsChan).Should(Receive()) close(done) }() handler.initialVersion = 13 handler.version = 37 handler.supportedVersions = []protocol.VersionNumber{13, 37, 42} - body := (&encryptedExtensionsTransportParameters{ - Parameters: parameters, - NegotiatedVersion: 37, - SupportedVersions: []protocol.VersionNumber{36, 37, 38}, - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&encryptedExtensionsTransportParameters{ + Parameters: parameters, + NegotiatedVersion: 37, + SupportedVersions: []protocol.VersionNumber{36, 37, 38}, + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) }) @@ -159,26 +139,28 @@ var _ = Describe("TLS Extension Handler, for the client", func() { handler.initialVersion = 13 handler.version = 37 handler.supportedVersions = []protocol.VersionNumber{13, 37, 42} - body := (&encryptedExtensionsTransportParameters{ - Parameters: parameters, - NegotiatedVersion: 38, - SupportedVersions: []protocol.VersionNumber{36, 37, 38}, - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&encryptedExtensionsTransportParameters{ + Parameters: parameters, + NegotiatedVersion: 38, + SupportedVersions: []protocol.VersionNumber{36, 37, 38}, + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).To(MatchError("VersionNegotiationMismatch: current version doesn't match negotiated_version")) }) It("errors if the current version is not contained in the server's supported versions", func() { handler.version = 42 - body := (&encryptedExtensionsTransportParameters{ - NegotiatedVersion: 42, - SupportedVersions: []protocol.VersionNumber{43, 44}, - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&encryptedExtensionsTransportParameters{ + NegotiatedVersion: 42, + SupportedVersions: []protocol.VersionNumber{43, 44}, + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).To(MatchError("VersionNegotiationMismatch: current version not included in the supported versions")) }) @@ -191,13 +173,14 @@ var _ = Describe("TLS Extension Handler, for the client", func() { ver, ok := protocol.ChooseSupportedVersion(handler.supportedVersions, serverSupportedVersions) Expect(ok).To(BeTrue()) Expect(ver).To(Equal(protocol.VersionNumber(43))) - body := (&encryptedExtensionsTransportParameters{ - NegotiatedVersion: 42, - SupportedVersions: serverSupportedVersions, - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&encryptedExtensionsTransportParameters{ + NegotiatedVersion: 42, + SupportedVersions: serverSupportedVersions, + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).To(MatchError("VersionNegotiationMismatch: would have picked a different version")) }) @@ -205,7 +188,7 @@ var _ = Describe("TLS Extension Handler, for the client", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - Eventually(handler.GetPeerParams()).Should(Receive()) + Eventually(paramsChan).Should(Receive()) close(done) }() @@ -217,14 +200,15 @@ var _ = Describe("TLS Extension Handler, for the client", func() { ver, ok := protocol.ChooseSupportedVersion(handler.supportedVersions, serverSupportedVersions) Expect(ok).To(BeTrue()) Expect(ver).To(Equal(protocol.VersionNumber(43))) - body := (&encryptedExtensionsTransportParameters{ - Parameters: parameters, - NegotiatedVersion: 42, - SupportedVersions: serverSupportedVersions, - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeEncryptedExtensions, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&encryptedExtensionsTransportParameters{ + Parameters: parameters, + NegotiatedVersion: 42, + SupportedVersions: serverSupportedVersions, + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeEncryptedExtensions), []qtls.Extension{ext}) Expect(err).ToNot(HaveOccurred()) Eventually(done).Should(BeClosed()) }) diff --git a/internal/handshake/tls_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go index 2d75d693..6755d899 100644 --- a/internal/handshake/tls_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -2,18 +2,16 @@ package handshake import ( "errors" - "fmt" - "github.com/lucas-clemente/quic-go/qerr" - - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/marten-seemann/qtls" ) type extensionHandlerServer struct { ourParams *TransportParameters - paramsChan chan TransportParameters + paramsChan chan<- TransportParameters version protocol.VersionNumber supportedVersions []protocol.VersionNumber @@ -21,62 +19,60 @@ type extensionHandlerServer struct { logger utils.Logger } -var _ mint.AppExtensionHandler = &extensionHandlerServer{} -var _ TLSExtensionHandler = &extensionHandlerServer{} +var _ tlsExtensionHandler = &extensionHandlerServer{} -// NewExtensionHandlerServer creates a new extension handler for the server -func NewExtensionHandlerServer( +// newExtensionHandlerServer creates a new extension handler for the server +func newExtensionHandlerServer( params *TransportParameters, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber, logger utils.Logger, -) TLSExtensionHandler { +) (tlsExtensionHandler, <-chan TransportParameters) { // Processing the ClientHello is performed statelessly (and from a single go-routine). // Therefore, we have to use a buffered chan to pass the transport parameters to that go routine. - paramsChan := make(chan TransportParameters, 1) + paramsChan := make(chan TransportParameters) return &extensionHandlerServer{ ourParams: params, paramsChan: paramsChan, supportedVersions: supportedVersions, version: version, logger: logger, - } + }, paramsChan } -func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error { - if hType != mint.HandshakeTypeEncryptedExtensions { +func (h *extensionHandlerServer) GetExtensions(msgType uint8) []qtls.Extension { + if messageType(msgType) != typeEncryptedExtensions { return nil } h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams) - eetp := &encryptedExtensionsTransportParameters{ - NegotiatedVersion: h.version, - SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions), - Parameters: *h.ourParams, - } - return el.Add(&tlsExtensionBody{data: eetp.Marshal()}) + return []qtls.Extension{{ + Type: quicTLSExtensionType, + Data: (&encryptedExtensionsTransportParameters{ + NegotiatedVersion: h.version, + SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions), + Parameters: *h.ourParams, + }).Marshal(), + }} } -func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error { - ext := &tlsExtensionBody{} - found, err := el.Find(ext) - if err != nil { - return err - } - - if hType != mint.HandshakeTypeClientHello { - if found { - return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType) - } +func (h *extensionHandlerServer) ReceivedExtensions(msgType uint8, exts []qtls.Extension) error { + if messageType(msgType) != typeClientHello { return nil } - + var found bool + chtp := &clientHelloTransportParameters{} + for _, ext := range exts { + if ext.Type != quicTLSExtensionType { + continue + } + if err := chtp.Unmarshal(ext.Data); err != nil { + return err + } + found = true + } if !found { return errors.New("ClientHello didn't contain a QUIC extension") } - chtp := &clientHelloTransportParameters{} - if err := chtp.Unmarshal(ext.data); err != nil { - return err - } // perform the stateless version negotiation validation: // make sure that we would have sent a Version Negotiation Packet if the client offered the initial version @@ -94,7 +90,3 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte h.paramsChan <- chtp.Parameters return nil } - -func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters { - return h.paramsChan -} diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index 9563febe..01e8b72f 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -2,12 +2,11 @@ package handshake import ( "bytes" - "fmt" "time" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/marten-seemann/qtls" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -15,40 +14,38 @@ import ( var _ = Describe("TLS Extension Handler, for the server", func() { var ( - handler *extensionHandlerServer - el mint.ExtensionList + handler *extensionHandlerServer + paramsChan <-chan TransportParameters ) BeforeEach(func() { - handler = NewExtensionHandlerServer(&TransportParameters{}, nil, protocol.VersionWhatever, utils.DefaultLogger).(*extensionHandlerServer) - el = make(mint.ExtensionList, 0) + var h tlsExtensionHandler + h, paramsChan = newExtensionHandlerServer( + &TransportParameters{}, + nil, + protocol.VersionWhatever, + utils.DefaultLogger, + ) + handler = h.(*extensionHandlerServer) }) Context("sending", func() { - It("only adds TransportParameters for the ClientHello", func() { + It("only adds TransportParameters for the Encrypted Extensions", func() { // test 2 other handshake types - err := handler.Send(mint.HandshakeTypeCertificateRequest, &el) - Expect(err).ToNot(HaveOccurred()) - Expect(el).To(BeEmpty()) - err = handler.Send(mint.HandshakeTypeEndOfEarlyData, &el) - Expect(err).ToNot(HaveOccurred()) - Expect(el).To(BeEmpty()) + exts := handler.GetExtensions(uint8(typeCertificate)) + Expect(exts).To(BeEmpty()) + exts = handler.GetExtensions(uint8(typeFinished)) + Expect(exts).To(BeEmpty()) }) It("adds TransportParameters to the EncryptedExtensions message", func() { handler.version = 666 versions := []protocol.VersionNumber{13, 37, 42} handler.supportedVersions = versions - err := handler.Send(mint.HandshakeTypeEncryptedExtensions, &el) - Expect(err).ToNot(HaveOccurred()) - Expect(el).To(HaveLen(1)) - ext := &tlsExtensionBody{} - found, err := el.Find(ext) - Expect(err).ToNot(HaveOccurred()) - Expect(found).To(BeTrue()) + exts := handler.GetExtensions(uint8(typeEncryptedExtensions)) + Expect(exts).To(HaveLen(1)) eetp := &encryptedExtensionsTransportParameters{} - err = eetp.Unmarshal(ext.data) - Expect(err).ToNot(HaveOccurred()) + Expect(eetp.Unmarshal(exts[0].Data)).To(Succeed()) Expect(eetp.NegotiatedVersion).To(BeEquivalentTo(666)) // the SupportedVersions will contain one reserved version number Expect(eetp.SupportedVersions).To(HaveLen(len(versions) + 1)) @@ -59,95 +56,108 @@ var _ = Describe("TLS Extension Handler, for the server", func() { }) Context("receiving", func() { - var ( - fakeBody *tlsExtensionBody - parameters TransportParameters - ) + var parameters TransportParameters - addClientHelloWithParameters := func(params TransportParameters) { - body := (&clientHelloTransportParameters{Parameters: params}).Marshal() - Expect(el.Add(&tlsExtensionBody{data: body})).To(Succeed()) + getClientHello := func(params TransportParameters) qtls.Extension { + return qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&clientHelloTransportParameters{Parameters: params}).Marshal(), + } } BeforeEach(func() { - fakeBody = &tlsExtensionBody{data: []byte("foobar foobar")} parameters = TransportParameters{IdleTimeout: 0x1337 * time.Second} }) - It("accepts the TransportParameters on the EncryptedExtensions message", func() { - addClientHelloWithParameters(parameters) - err := handler.Receive(mint.HandshakeTypeClientHello, &el) - Expect(err).ToNot(HaveOccurred()) + It("sends the transport parameters on the channel", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + ext := getClientHello(parameters) + err := handler.ReceivedExtensions(uint8(typeClientHello), []qtls.Extension{ext}) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() var params TransportParameters - Expect(handler.GetPeerParams()).To(Receive(¶ms)) + Consistently(done).ShouldNot(BeClosed()) + Expect(paramsChan).To(Receive(¶ms)) Expect(params.IdleTimeout).To(Equal(0x1337 * time.Second)) + Eventually(done).Should(BeClosed()) }) It("errors if the ClientHello doesn't contain TransportParameters", func() { - err := handler.Receive(mint.HandshakeTypeClientHello, &el) + err := handler.ReceivedExtensions(uint8(typeClientHello), nil) Expect(err).To(MatchError("ClientHello didn't contain a QUIC extension")) }) - It("ignores messages without TransportParameters, if they are not required", func() { - err := handler.Receive(mint.HandshakeTypeCertificate, &el) - Expect(err).ToNot(HaveOccurred()) - }) - It("errors if it can't unmarshal the TransportParameters", func() { - err := el.Add(fakeBody) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeClientHello, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: []byte("invalid extension data"), + } + err := handler.ReceivedExtensions(uint8(typeClientHello), []qtls.Extension{ext}) Expect(err).To(HaveOccurred()) // this will be some kind of decoding error }) - It("rejects messages other than the ClientHello that contain TransportParameters", func() { - addClientHelloWithParameters(parameters) - err := handler.Receive(mint.HandshakeTypeCertificateRequest, &el) - Expect(err).To(MatchError(fmt.Sprintf("Unexpected QUIC extension in handshake message %d", mint.HandshakeTypeCertificateRequest))) - }) - It("rejects messages that contain a stateless reset token", func() { parameters.StatelessResetToken = bytes.Repeat([]byte{0}, 16) - addClientHelloWithParameters(parameters) - err := handler.Receive(mint.HandshakeTypeClientHello, &el) + ext := getClientHello(parameters) + err := handler.ReceivedExtensions(uint8(typeClientHello), []qtls.Extension{ext}) Expect(err).To(MatchError("client sent a stateless reset token")) }) Context("Version Negotiation", func() { It("accepts a ClientHello, when no version negotiation was performed", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + <-paramsChan + close(done) + }() handler.version = 42 - body := (&clientHelloTransportParameters{ - InitialVersion: 42, - Parameters: parameters, - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeClientHello, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&clientHelloTransportParameters{ + InitialVersion: 42, + Parameters: parameters, + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeClientHello), []qtls.Extension{ext}) Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) }) It("accepts a valid version negotiation", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + <-paramsChan + close(done) + }() handler.version = 42 handler.supportedVersions = []protocol.VersionNumber{13, 37, 42} - body := (&clientHelloTransportParameters{ - InitialVersion: 22, // this must be an unsupported version - Parameters: parameters, - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeClientHello, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&clientHelloTransportParameters{ + InitialVersion: 22, // this must be an unsupported version + Parameters: parameters, + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeClientHello), []qtls.Extension{ext}) Expect(err).ToNot(HaveOccurred()) + Eventually(done).Should(BeClosed()) }) It("erros when a version negotiation was performed, although we already support the initial version", func() { handler.supportedVersions = []protocol.VersionNumber{11, 12, 13} handler.version = 13 - body := (&clientHelloTransportParameters{ - InitialVersion: 11, // this is an supported version - }).Marshal() - err := el.Add(&tlsExtensionBody{data: body}) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeClientHello, &el) + ext := qtls.Extension{ + Type: quicTLSExtensionType, + Data: (&clientHelloTransportParameters{ + InitialVersion: 11, // this is an supported version + }).Marshal(), + } + err := handler.ReceivedExtensions(uint8(typeClientHello), []qtls.Extension{ext}) Expect(err).To(MatchError("VersionNegotiationMismatch: Client should have used the initial version")) }) }) diff --git a/internal/handshake/tls_extension_test.go b/internal/handshake/tls_extension_test.go index ef0a6dc0..27e705d6 100644 --- a/internal/handshake/tls_extension_test.go +++ b/internal/handshake/tls_extension_test.go @@ -10,7 +10,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("TLS extension body", func() { +var _ = Describe("QUIC TLS Extension", func() { Context("Client Hello Transport Parameters", func() { It("marshals and unmarshals", func() { chtp := &clientHelloTransportParameters{ @@ -66,30 +66,4 @@ var _ = Describe("TLS extension body", func() { } }) }) - - Context("TLS Extension Body", func() { - var extBody *tlsExtensionBody - - BeforeEach(func() { - extBody = &tlsExtensionBody{} - }) - - It("has the right TLS extension type", func() { - Expect(extBody.Type()).To(BeEquivalentTo(quicTLSExtensionType)) - }) - - It("saves the body when unmarshalling", func() { - n, err := extBody.Unmarshal([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(6)) - Expect(extBody.data).To(Equal([]byte("foobar"))) - }) - - It("returns the body when marshalling", func() { - extBody.data = []byte("foo") - data, err := extBody.Marshal() - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foo"))) - }) - }) }) diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index b01e97e4..158cb339 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,6 +1,5 @@ package mocks -//go:generate sh -c "../mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler" //go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer" //go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" //go:generate sh -c "../mockgen_internal.sh mockackhandler ackhandler/sent_packet_handler.go github.com/lucas-clemente/quic-go/internal/ackhandler SentPacketHandler" diff --git a/internal/mocks/tls_extension_handler.go b/internal/mocks/tls_extension_handler.go deleted file mode 100644 index fcceee2e..00000000 --- a/internal/mocks/tls_extension_handler.go +++ /dev/null @@ -1,72 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: TLSExtensionHandler) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - mint "github.com/bifurcation/mint" - gomock "github.com/golang/mock/gomock" - handshake "github.com/lucas-clemente/quic-go/internal/handshake" -) - -// MockTLSExtensionHandler is a mock of TLSExtensionHandler interface -type MockTLSExtensionHandler struct { - ctrl *gomock.Controller - recorder *MockTLSExtensionHandlerMockRecorder -} - -// MockTLSExtensionHandlerMockRecorder is the mock recorder for MockTLSExtensionHandler -type MockTLSExtensionHandlerMockRecorder struct { - mock *MockTLSExtensionHandler -} - -// NewMockTLSExtensionHandler creates a new mock instance -func NewMockTLSExtensionHandler(ctrl *gomock.Controller) *MockTLSExtensionHandler { - mock := &MockTLSExtensionHandler{ctrl: ctrl} - mock.recorder = &MockTLSExtensionHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockTLSExtensionHandler) EXPECT() *MockTLSExtensionHandlerMockRecorder { - return m.recorder -} - -// GetPeerParams mocks base method -func (m *MockTLSExtensionHandler) GetPeerParams() <-chan handshake.TransportParameters { - ret := m.ctrl.Call(m, "GetPeerParams") - ret0, _ := ret[0].(<-chan handshake.TransportParameters) - return ret0 -} - -// GetPeerParams indicates an expected call of GetPeerParams -func (mr *MockTLSExtensionHandlerMockRecorder) GetPeerParams() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerParams", reflect.TypeOf((*MockTLSExtensionHandler)(nil).GetPeerParams)) -} - -// Receive mocks base method -func (m *MockTLSExtensionHandler) Receive(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error { - ret := m.ctrl.Call(m, "Receive", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Receive indicates an expected call of Receive -func (mr *MockTLSExtensionHandlerMockRecorder) Receive(arg0, arg1 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Receive), arg0, arg1) -} - -// Send mocks base method -func (m *MockTLSExtensionHandler) Send(arg0 mint.HandshakeType, arg1 *mint.ExtensionList) error { - ret := m.ctrl.Call(m, "Send", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Send indicates an expected call of Send -func (mr *MockTLSExtensionHandlerMockRecorder) Send(arg0, arg1 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockTLSExtensionHandler)(nil).Send), arg0, arg1) -} diff --git a/internal/protocol/server_parameters.go b/internal/protocol/server_parameters.go index aa92c822..8ad8f079 100644 --- a/internal/protocol/server_parameters.go +++ b/internal/protocol/server_parameters.go @@ -110,6 +110,10 @@ const CryptoMaxParams = 128 // CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message. const CryptoParameterMaxLength = 4000 +// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. +// This limits the size of the ClientHello and Certificates that can be received. +const MaxCryptoStreamOffset = 16 * (1 << 10) + // EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX. const EphermalKeyLifetime = time.Minute diff --git a/internal/protocol/stream_id.go b/internal/protocol/stream_id.go index a0dced0c..7b211c03 100644 --- a/internal/protocol/stream_id.go +++ b/internal/protocol/stream_id.go @@ -14,7 +14,7 @@ func MaxBidiStreamID(numStreams int, pers Perspective) StreamID { if pers == PerspectiveClient { first = 1 } else { - first = 4 + first = 0 } return first + 4*StreamID(numStreams-1) } diff --git a/internal/protocol/stream_id_test.go b/internal/protocol/stream_id_test.go index cca4f928..306ae31a 100644 --- a/internal/protocol/stream_id_test.go +++ b/internal/protocol/stream_id_test.go @@ -14,12 +14,12 @@ var _ = Describe("Stream ID", func() { It("allows one", func() { Expect(MaxBidiStreamID(1, PerspectiveClient)).To(Equal(StreamID(1))) - Expect(MaxBidiStreamID(1, PerspectiveServer)).To(Equal(StreamID(4))) + Expect(MaxBidiStreamID(1, PerspectiveServer)).To(Equal(StreamID(0))) }) It("allows many", func() { Expect(MaxBidiStreamID(100, PerspectiveClient)).To(Equal(StreamID(397))) - Expect(MaxBidiStreamID(100, PerspectiveServer)).To(Equal(StreamID(400))) + Expect(MaxBidiStreamID(100, PerspectiveServer)).To(Equal(StreamID(396))) }) }) diff --git a/internal/protocol/version.go b/internal/protocol/version.go index 299c6fcc..4339086a 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -68,12 +68,10 @@ func (vn VersionNumber) ToAltSvc() string { return fmt.Sprintf("%d", vn) } -// CryptoStreamID gets the Stream ID of the crypto stream -func (vn VersionNumber) CryptoStreamID() StreamID { - if vn.isGQUIC() { - return 1 - } - return 0 +// IsCryptoStream says if a stream is the gQUIC crypto stream. +// It never returns true for IETF QUIC. +func (vn VersionNumber) IsCryptoStream(id StreamID) bool { + return vn.isGQUIC() && id == 1 } // UsesIETFFrameFormat tells if this version uses the IETF frame format @@ -108,13 +106,10 @@ func (vn VersionNumber) UsesVarintPacketNumbers() bool { // StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool { - if id == vn.CryptoStreamID() { - return false + if !vn.isGQUIC() { + return true } - if vn.isGQUIC() && id == 3 { - return false - } - return true + return id != 1 && id != 3 } func (vn VersionNumber) isGQUIC() bool { diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 36326f1c..a294f31b 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -65,11 +65,22 @@ var _ = Describe("Version", func() { Expect(VersionNumber(0x51303438).ToAltSvc()).To(Equal("48")) }) - It("tells the Stream ID of the crypto stream", func() { - Expect(Version39.CryptoStreamID()).To(Equal(StreamID(1))) - Expect(Version43.CryptoStreamID()).To(Equal(StreamID(1))) - Expect(Version44.CryptoStreamID()).To(Equal(StreamID(1))) - Expect(VersionTLS.CryptoStreamID()).To(Equal(StreamID(0))) + It("says if a stream is the crypto stream, for gQUIC", func() { + for _, v := range []VersionNumber{Version39, Version43, Version44} { + version := v + Expect(version.IsCryptoStream(1)).To(BeTrue()) + Expect(version.IsCryptoStream(2)).To(BeFalse()) + Expect(version.IsCryptoStream(3)).To(BeFalse()) + Expect(version.IsCryptoStream(4)).To(BeFalse()) + Expect(version.IsCryptoStream(5)).To(BeFalse()) + } + }) + + It("says if a stream is the crypto stream, for TLS", func() { + // all streams contribute to connection-level flow control + for id := StreamID(0); id < 10; id++ { + Expect(VersionTLS.IsCryptoStream(id)).To(BeFalse()) + } }) It("tells if a version uses the IETF frame types", func() { @@ -122,10 +133,10 @@ var _ = Describe("Version", func() { }) It("says if a stream contributes to connection-level flowcontrol, for TLS", func() { - Expect(VersionTLS.StreamContributesToConnectionFlowControl(0)).To(BeFalse()) - Expect(VersionTLS.StreamContributesToConnectionFlowControl(1)).To(BeTrue()) - Expect(VersionTLS.StreamContributesToConnectionFlowControl(2)).To(BeTrue()) - Expect(VersionTLS.StreamContributesToConnectionFlowControl(3)).To(BeTrue()) + // all streams contribute to connection-level flow control + for id := StreamID(0); id < 10; id++ { + Expect(VersionTLS.StreamContributesToConnectionFlowControl(id)).To(BeTrue()) + } }) It("recognizes supported versions", func() { diff --git a/mint_utils.go b/mint_utils.go deleted file mode 100644 index 657adb56..00000000 --- a/mint_utils.go +++ /dev/null @@ -1,52 +0,0 @@ -package quic - -import ( - gocrypto "crypto" - "crypto/tls" - "crypto/x509" - "errors" - - "github.com/bifurcation/mint" - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) { - mconf := &mint.Config{ - NonBlocking: true, - CipherSuites: []mint.CipherSuite{ - mint.TLS_AES_128_GCM_SHA256, - mint.TLS_AES_256_GCM_SHA384, - }, - } - if tlsConf != nil { - mconf.ServerName = tlsConf.ServerName - mconf.InsecureSkipVerify = tlsConf.InsecureSkipVerify - mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates)) - mconf.RootCAs = tlsConf.RootCAs - mconf.VerifyPeerCertificate = tlsConf.VerifyPeerCertificate - for i, certChain := range tlsConf.Certificates { - mconf.Certificates[i] = &mint.Certificate{ - Chain: make([]*x509.Certificate, len(certChain.Certificate)), - PrivateKey: certChain.PrivateKey.(gocrypto.Signer), - } - for j, cert := range certChain.Certificate { - c, err := x509.ParseCertificate(cert) - if err != nil { - return nil, err - } - mconf.Certificates[i].Chain[j] = c - } - } - switch tlsConf.ClientAuth { - case tls.NoClientCert: - case tls.RequireAnyClientCert: - mconf.RequireClientAuth = true - default: - return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert") - } - } - if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil { - return nil, err - } - return mconf, nil -} diff --git a/mint_utils_test.go b/mint_utils_test.go deleted file mode 100644 index 39b17c5e..00000000 --- a/mint_utils_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package quic - -import ( - "crypto/tls" - "crypto/x509" - "errors" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/testdata" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Mint Utils", func() { - Context("generating a mint.Config", func() { - It("sets non-blocking mode", func() { - mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(mintConf.NonBlocking).To(BeTrue()) - }) - - It("sets the certificate chain", func() { - tlsConf := testdata.GetTLSConfig() - mintConf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(mintConf.Certificates).ToNot(BeEmpty()) - Expect(mintConf.Certificates).To(HaveLen(len(tlsConf.Certificates))) - }) - - It("copies values from the tls.Config", func() { - verifyErr := errors.New("test err") - certPool := &x509.CertPool{} - tlsConf := &tls.Config{ - RootCAs: certPool, - ServerName: "www.example.com", - InsecureSkipVerify: true, - VerifyPeerCertificate: func(_ [][]byte, _ [][]*x509.Certificate) error { - return verifyErr - }, - } - mintConf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(mintConf.RootCAs).To(Equal(certPool)) - Expect(mintConf.ServerName).To(Equal("www.example.com")) - Expect(mintConf.InsecureSkipVerify).To(BeTrue()) - Expect(mintConf.VerifyPeerCertificate(nil, nil)).To(MatchError(verifyErr)) - }) - - It("requires client authentication", func() { - mintConf, err := tlsToMintConfig(nil, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(mintConf.RequireClientAuth).To(BeFalse()) - conf := &tls.Config{ClientAuth: tls.RequireAnyClientCert} - mintConf, err = tlsToMintConfig(conf, protocol.PerspectiveClient) - Expect(err).ToNot(HaveOccurred()) - Expect(mintConf.RequireClientAuth).To(BeTrue()) - }) - - It("rejects unsupported client auth types", func() { - conf := &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert} - _, err := tlsToMintConfig(conf, protocol.PerspectiveClient) - Expect(err).To(MatchError("mint currently only support ClientAuthType RequireAnyClientCert")) - }) - }) -}) diff --git a/mock_crypto_data_handler.go b/mock_crypto_data_handler.go new file mode 100644 index 00000000..58efb560 --- /dev/null +++ b/mock_crypto_data_handler.go @@ -0,0 +1,47 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: CryptoDataHandler) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockCryptoDataHandler is a mock of CryptoDataHandler interface +type MockCryptoDataHandler struct { + ctrl *gomock.Controller + recorder *MockCryptoDataHandlerMockRecorder +} + +// MockCryptoDataHandlerMockRecorder is the mock recorder for MockCryptoDataHandler +type MockCryptoDataHandlerMockRecorder struct { + mock *MockCryptoDataHandler +} + +// NewMockCryptoDataHandler creates a new mock instance +func NewMockCryptoDataHandler(ctrl *gomock.Controller) *MockCryptoDataHandler { + mock := &MockCryptoDataHandler{ctrl: ctrl} + mock.recorder = &MockCryptoDataHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { + return m.recorder +} + +// HandleData mocks base method +func (m *MockCryptoDataHandler) HandleData(arg0 []byte, arg1 protocol.EncryptionLevel) error { + ret := m.ctrl.Call(m, "HandleData", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleData indicates an expected call of HandleData +func (mr *MockCryptoDataHandlerMockRecorder) HandleData(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleData", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleData), arg0, arg1) +} diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index 35cccf18..66de8d0f 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -35,29 +35,52 @@ func (m *MockCryptoStream) EXPECT() *MockCryptoStreamMockRecorder { return m.recorder } -// Read mocks base method -func (m *MockCryptoStream) Read(arg0 []byte) (int, error) { - ret := m.ctrl.Call(m, "Read", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read -func (mr *MockCryptoStreamMockRecorder) Read(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockCryptoStream)(nil).Read), arg0) -} - -// StreamID mocks base method -func (m *MockCryptoStream) StreamID() protocol.StreamID { - ret := m.ctrl.Call(m, "StreamID") - ret0, _ := ret[0].(protocol.StreamID) +// GetCryptoData mocks base method +func (m *MockCryptoStream) GetCryptoData() []byte { + ret := m.ctrl.Call(m, "GetCryptoData") + ret0, _ := ret[0].([]byte) return ret0 } -// StreamID indicates an expected call of StreamID -func (mr *MockCryptoStreamMockRecorder) StreamID() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockCryptoStream)(nil).StreamID)) +// GetCryptoData indicates an expected call of GetCryptoData +func (mr *MockCryptoStreamMockRecorder) GetCryptoData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoData", reflect.TypeOf((*MockCryptoStream)(nil).GetCryptoData)) +} + +// HandleCryptoFrame mocks base method +func (m *MockCryptoStream) HandleCryptoFrame(arg0 *wire.CryptoFrame) error { + ret := m.ctrl.Call(m, "HandleCryptoFrame", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleCryptoFrame indicates an expected call of HandleCryptoFrame +func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).HandleCryptoFrame), arg0) +} + +// HasData mocks base method +func (m *MockCryptoStream) HasData() bool { + ret := m.ctrl.Call(m, "HasData") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasData indicates an expected call of HasData +func (mr *MockCryptoStreamMockRecorder) HasData() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockCryptoStream)(nil).HasData)) +} + +// PopCryptoFrame mocks base method +func (m *MockCryptoStream) PopCryptoFrame(arg0 protocol.ByteCount) *wire.CryptoFrame { + ret := m.ctrl.Call(m, "PopCryptoFrame", arg0) + ret0, _ := ret[0].(*wire.CryptoFrame) + return ret0 +} + +// PopCryptoFrame indicates an expected call of PopCryptoFrame +func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).PopCryptoFrame), arg0) } // Write mocks base method @@ -72,82 +95,3 @@ func (m *MockCryptoStream) Write(arg0 []byte) (int, error) { func (mr *MockCryptoStreamMockRecorder) Write(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0) } - -// closeForShutdown mocks base method -func (m *MockCryptoStream) closeForShutdown(arg0 error) { - m.ctrl.Call(m, "closeForShutdown", arg0) -} - -// closeForShutdown indicates an expected call of closeForShutdown -func (mr *MockCryptoStreamMockRecorder) closeForShutdown(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockCryptoStream)(nil).closeForShutdown), arg0) -} - -// getWindowUpdate mocks base method -func (m *MockCryptoStream) getWindowUpdate() protocol.ByteCount { - ret := m.ctrl.Call(m, "getWindowUpdate") - ret0, _ := ret[0].(protocol.ByteCount) - return ret0 -} - -// getWindowUpdate indicates an expected call of getWindowUpdate -func (mr *MockCryptoStreamMockRecorder) getWindowUpdate() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockCryptoStream)(nil).getWindowUpdate)) -} - -// handleMaxStreamDataFrame mocks base method -func (m *MockCryptoStream) handleMaxStreamDataFrame(arg0 *wire.MaxStreamDataFrame) { - m.ctrl.Call(m, "handleMaxStreamDataFrame", arg0) -} - -// handleMaxStreamDataFrame indicates an expected call of handleMaxStreamDataFrame -func (mr *MockCryptoStreamMockRecorder) handleMaxStreamDataFrame(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleMaxStreamDataFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleMaxStreamDataFrame), arg0) -} - -// handleStreamFrame mocks base method -func (m *MockCryptoStream) handleStreamFrame(arg0 *wire.StreamFrame) error { - ret := m.ctrl.Call(m, "handleStreamFrame", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// handleStreamFrame indicates an expected call of handleStreamFrame -func (mr *MockCryptoStreamMockRecorder) handleStreamFrame(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).handleStreamFrame), arg0) -} - -// hasData mocks base method -func (m *MockCryptoStream) hasData() bool { - ret := m.ctrl.Call(m, "hasData") - ret0, _ := ret[0].(bool) - return ret0 -} - -// hasData indicates an expected call of hasData -func (mr *MockCryptoStreamMockRecorder) hasData() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockCryptoStream)(nil).hasData)) -} - -// popStreamFrame mocks base method -func (m *MockCryptoStream) popStreamFrame(arg0 protocol.ByteCount) (*wire.StreamFrame, bool) { - ret := m.ctrl.Call(m, "popStreamFrame", arg0) - ret0, _ := ret[0].(*wire.StreamFrame) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// popStreamFrame indicates an expected call of popStreamFrame -func (mr *MockCryptoStreamMockRecorder) popStreamFrame(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockCryptoStream)(nil).popStreamFrame), arg0) -} - -// setReadOffset mocks base method -func (m *MockCryptoStream) setReadOffset(arg0 protocol.ByteCount) { - m.ctrl.Call(m, "setReadOffset", arg0) -} - -// setReadOffset indicates an expected call of setReadOffset -func (mr *MockCryptoStreamMockRecorder) setReadOffset(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setReadOffset", reflect.TypeOf((*MockCryptoStream)(nil).setReadOffset), arg0) -} diff --git a/mock_quic_aead_test.go b/mock_quic_aead_test.go index 63a2a5a7..80bb2224 100644 --- a/mock_quic_aead_test.go +++ b/mock_quic_aead_test.go @@ -59,3 +59,16 @@ func (m *MockQuicAEAD) OpenHandshake(arg0, arg1 []byte, arg2 protocol.PacketNumb func (mr *MockQuicAEADMockRecorder) OpenHandshake(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenHandshake", reflect.TypeOf((*MockQuicAEAD)(nil).OpenHandshake), arg0, arg1, arg2, arg3) } + +// OpenInitial mocks base method +func (m *MockQuicAEAD) OpenInitial(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) ([]byte, error) { + ret := m.ctrl.Call(m, "OpenInitial", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// OpenInitial indicates an expected call of OpenInitial +func (mr *MockQuicAEADMockRecorder) OpenInitial(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenInitial", reflect.TypeOf((*MockQuicAEAD)(nil).OpenInitial), arg0, arg1, arg2, arg3) +} diff --git a/mock_sealing_manager_test.go b/mock_sealing_manager_test.go index 73edfce2..b5a6bdc5 100644 --- a/mock_sealing_manager_test.go +++ b/mock_sealing_manager_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: SealingManager) +// Source: github.com/lucas-clemente/quic-go (interfaces: SealingManagerLegacy) // Package quic is a generated GoMock package. package quic @@ -12,31 +12,31 @@ import ( protocol "github.com/lucas-clemente/quic-go/internal/protocol" ) -// MockSealingManager is a mock of SealingManager interface -type MockSealingManager struct { +// MockSealingManagerLegacy is a mock of SealingManagerLegacy interface +type MockSealingManagerLegacy struct { ctrl *gomock.Controller - recorder *MockSealingManagerMockRecorder + recorder *MockSealingManagerLegacyMockRecorder } -// MockSealingManagerMockRecorder is the mock recorder for MockSealingManager -type MockSealingManagerMockRecorder struct { - mock *MockSealingManager +// MockSealingManagerLegacyMockRecorder is the mock recorder for MockSealingManagerLegacy +type MockSealingManagerLegacyMockRecorder struct { + mock *MockSealingManagerLegacy } -// NewMockSealingManager creates a new mock instance -func NewMockSealingManager(ctrl *gomock.Controller) *MockSealingManager { - mock := &MockSealingManager{ctrl: ctrl} - mock.recorder = &MockSealingManagerMockRecorder{mock} +// NewMockSealingManagerLegacy creates a new mock instance +func NewMockSealingManagerLegacy(ctrl *gomock.Controller) *MockSealingManagerLegacy { + mock := &MockSealingManagerLegacy{ctrl: ctrl} + mock.recorder = &MockSealingManagerLegacyMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use -func (m *MockSealingManager) EXPECT() *MockSealingManagerMockRecorder { +func (m *MockSealingManagerLegacy) EXPECT() *MockSealingManagerLegacyMockRecorder { return m.recorder } // GetSealer mocks base method -func (m *MockSealingManager) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) { +func (m *MockSealingManagerLegacy) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) { ret := m.ctrl.Call(m, "GetSealer") ret0, _ := ret[0].(protocol.EncryptionLevel) ret1, _ := ret[1].(handshake.Sealer) @@ -44,12 +44,12 @@ func (m *MockSealingManager) GetSealer() (protocol.EncryptionLevel, handshake.Se } // GetSealer indicates an expected call of GetSealer -func (mr *MockSealingManagerMockRecorder) GetSealer() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealer", reflect.TypeOf((*MockSealingManager)(nil).GetSealer)) +func (mr *MockSealingManagerLegacyMockRecorder) GetSealer() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealer", reflect.TypeOf((*MockSealingManagerLegacy)(nil).GetSealer)) } // GetSealerForCryptoStream mocks base method -func (m *MockSealingManager) GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) { +func (m *MockSealingManagerLegacy) GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) { ret := m.ctrl.Call(m, "GetSealerForCryptoStream") ret0, _ := ret[0].(protocol.EncryptionLevel) ret1, _ := ret[1].(handshake.Sealer) @@ -57,12 +57,12 @@ func (m *MockSealingManager) GetSealerForCryptoStream() (protocol.EncryptionLeve } // GetSealerForCryptoStream indicates an expected call of GetSealerForCryptoStream -func (mr *MockSealingManagerMockRecorder) GetSealerForCryptoStream() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerForCryptoStream", reflect.TypeOf((*MockSealingManager)(nil).GetSealerForCryptoStream)) +func (mr *MockSealingManagerLegacyMockRecorder) GetSealerForCryptoStream() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerForCryptoStream", reflect.TypeOf((*MockSealingManagerLegacy)(nil).GetSealerForCryptoStream)) } // GetSealerWithEncryptionLevel mocks base method -func (m *MockSealingManager) GetSealerWithEncryptionLevel(arg0 protocol.EncryptionLevel) (handshake.Sealer, error) { +func (m *MockSealingManagerLegacy) GetSealerWithEncryptionLevel(arg0 protocol.EncryptionLevel) (handshake.Sealer, error) { ret := m.ctrl.Call(m, "GetSealerWithEncryptionLevel", arg0) ret0, _ := ret[0].(handshake.Sealer) ret1, _ := ret[1].(error) @@ -70,6 +70,6 @@ func (m *MockSealingManager) GetSealerWithEncryptionLevel(arg0 protocol.Encrypti } // GetSealerWithEncryptionLevel indicates an expected call of GetSealerWithEncryptionLevel -func (mr *MockSealingManagerMockRecorder) GetSealerWithEncryptionLevel(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerWithEncryptionLevel", reflect.TypeOf((*MockSealingManager)(nil).GetSealerWithEncryptionLevel), arg0) +func (mr *MockSealingManagerLegacyMockRecorder) GetSealerWithEncryptionLevel(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSealerWithEncryptionLevel", reflect.TypeOf((*MockSealingManagerLegacy)(nil).GetSealerWithEncryptionLevel), arg0) } diff --git a/mockgen.go b/mockgen.go index c89b6541..d921a9f0 100644 --- a/mockgen.go +++ b/mockgen.go @@ -1,15 +1,16 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI" //go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI" //go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender" //go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_data_handler.go github.com/lucas-clemente/quic-go cryptoDataHandler" //go:generate sh -c "./mockgen_private.sh quic mock_frame_source_test.go github.com/lucas-clemente/quic-go frameSource" //go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/lucas-clemente/quic-go ackFrameSource" -//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager" -//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager" +//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManagerLegacy" //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" //go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" //go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD" diff --git a/packet_packer.go b/packet_packer.go index 463a9428..afde1d91 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -61,7 +61,6 @@ func getMaxPacketSize(addr net.Addr) protocol.ByteCount { type sealingManager interface { GetSealer() (protocol.EncryptionLevel, handshake.Sealer) - GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error) } @@ -83,11 +82,13 @@ type packetPacker struct { version protocol.VersionNumber cryptoSetup sealingManager + initialStream cryptoStream + handshakeStream cryptoStream + token []byte packetNumberGenerator *packetNumberGenerator getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen - cryptoStream cryptoStream framer frameSource acks ackFrameSource @@ -101,11 +102,12 @@ var _ packer = &packetPacker{} func newPacketPacker( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + initialStream cryptoStream, + handshakeStream cryptoStream, initialPacketNumber protocol.PacketNumber, getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen, remoteAddr net.Addr, // only used for determining the max packet size token []byte, - cryptoStream cryptoStream, cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, @@ -113,11 +115,12 @@ func newPacketPacker( version protocol.VersionNumber, ) *packetPacker { return &packetPacker{ - cryptoStream: cryptoStream, cryptoSetup: cryptoSetup, token: token, destConnID: destConnID, srcConnID: srcConnID, + initialStream: initialStream, + handshakeStream: handshakeStream, perspective: perspective, version: version, framer: framer, @@ -147,6 +150,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { if ack == nil { return nil, nil } + // TODO(#1534): only pack ACKs with the right encryption level encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) frames := []wire.Frame{ack} @@ -163,7 +167,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { // For packets sent after completion of the handshake, it might happen that 2 packets have to be sent. // This can happen e.g. when a longer packet number is used in the header. func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) { - if packet.EncryptionLevel != protocol.EncryptionForwardSecure { + if packet.EncryptionLevel != protocol.Encryption1RTT { p, err := p.packHandshakeRetransmission(packet) return []*packedPacket{p}, err } @@ -180,7 +184,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP } var packets []*packedPacket - encLevel, sealer := p.cryptoSetup.GetSealer() + encLevel := packet.EncryptionLevel + sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel) + if err != nil { + return nil, err + } for len(controlFrames) > 0 || len(streamFrames) > 0 { var frames []wire.Frame var length protocol.ByteCount @@ -238,7 +246,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP return packets, nil } -// packHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption +// packHandshakeRetransmission retransmits a handshake packet func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) { sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel) if err != nil { @@ -275,7 +283,6 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { } encLevel, sealer := p.cryptoSetup.GetSealer() - header := p.getHeader(encLevel) headerLength, err := header.GetLength(p.version) if err != nil { @@ -317,25 +324,33 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { } func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { - if !p.cryptoStream.hasData() { + var s cryptoStream + var encLevel protocol.EncryptionLevel + if p.initialStream.HasData() { + s = p.initialStream + encLevel = protocol.EncryptionInitial + } else if p.handshakeStream.HasData() { + s = p.handshakeStream + encLevel = protocol.EncryptionHandshake + } + if s == nil { return nil, nil } - encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream() - header := p.getHeader(encLevel) - headerLength, err := header.GetLength(p.version) + hdr := p.getHeader(encLevel) + hdrLen, _ := hdr.GetLength(p.version) + sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel) if err != nil { return nil, err } - maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength - sf, _ := p.cryptoStream.popStreamFrame(maxLen) - sf.DataLenPresent = false - frames := []wire.Frame{sf} - raw, err := p.writeAndSealPacket(header, frames, sealer) + var length protocol.ByteCount + cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length) + frames := []wire.Frame{cf} + raw, err := p.writeAndSealPacket(hdr, frames, sealer) if err != nil { return nil, err } return &packedPacket{ - header: header, + header: hdr, raw: raw, frames: frames, encryptionLevel: encLevel, @@ -390,16 +405,16 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header DestConnectionID: p.destConnID, } - if encLevel != protocol.EncryptionForwardSecure { + if encLevel != protocol.Encryption1RTT { header.IsLongHeader = true header.SrcConnectionID = p.srcConnID // Set the payload len to maximum size. // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns. header.PayloadLen = p.maxPacketSize - if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient { + switch encLevel { + case protocol.EncryptionInitial: header.Type = protocol.PacketTypeInitial - header.Token = p.token - } else { + case protocol.EncryptionHandshake: header.Type = protocol.PacketTypeHandshake } } @@ -415,9 +430,14 @@ func (p *packetPacker) writeAndSealPacket( raw := *getPacketBuffer() buffer := bytes.NewBuffer(raw[:0]) + addPadding := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial && !p.hasSentPacket + // the payload length is only needed for Long Headers if header.IsLongHeader { - if header.Type == protocol.PacketTypeInitial { + if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { + header.Token = p.token + } + if addPadding { headerLen, _ := header.GetLength(p.version) header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen } else { @@ -435,7 +455,7 @@ func (p *packetPacker) writeAndSealPacket( payloadStartIndex := buffer.Len() // the Initial packet needs to be padded, so the last STREAM frame must have the data length present - if header.Type == protocol.PacketTypeInitial { + if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { lastFrame := frames[len(frames)-1] if sf, ok := lastFrame.(*wire.StreamFrame); ok { sf.DataLenPresent = true @@ -446,8 +466,7 @@ func (p *packetPacker) writeAndSealPacket( return nil, err } } - // if this is an Initial packet, we need to pad it to fulfill the minimum size requirement - if header.Type == protocol.PacketTypeInitial { + if addPadding { paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len() if paddingLen > 0 { buffer.Write(bytes.Repeat([]byte{0}, paddingLen)) @@ -471,10 +490,7 @@ func (p *packetPacker) writeAndSealPacket( } func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool { - if p.perspective == protocol.PerspectiveClient { - return encLevel >= protocol.EncryptionSecure - } - return encLevel == protocol.EncryptionForwardSecure + return encLevel == protocol.Encryption1RTT } func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) { diff --git a/packet_packer_legacy.go b/packet_packer_legacy.go index 2fb20ed5..c4c55455 100644 --- a/packet_packer_legacy.go +++ b/packet_packer_legacy.go @@ -12,6 +12,12 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type sealingManagerLegacy interface { + GetSealer() (protocol.EncryptionLevel, handshake.Sealer) + GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) + GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error) +} + // sentAndReceivedPacketManager is only needed until STOP_WAITING is removed type sentAndReceivedPacketManager struct { ackhandler.SentPacketHandler @@ -26,13 +32,13 @@ type packetPackerLegacy struct { perspective protocol.Perspective version protocol.VersionNumber - cryptoSetup sealingManager + cryptoSetup sealingManagerLegacy divNonce []byte packetNumberGenerator *packetNumberGenerator getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen - cryptoStream cryptoStream + cryptoStream streamI framer frameSource acks ackFrameSource @@ -50,8 +56,8 @@ func newPacketPackerLegacy( getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen, remoteAddr net.Addr, // only used for determining the max packet size divNonce []byte, - cryptoStream cryptoStream, - cryptoSetup sealingManager, + cryptoStream streamI, + cryptoSetup sealingManagerLegacy, framer frameSource, acks ackFrameSource, perspective protocol.Perspective, diff --git a/packet_packer_legacy_test.go b/packet_packer_legacy_test.go index 5067c53e..47eaae78 100644 --- a/packet_packer_legacy_test.go +++ b/packet_packer_legacy_test.go @@ -22,8 +22,8 @@ var _ = Describe("Packet packer (legacy)", func() { packer *packetPackerLegacy framer *MockFrameSource ackFramer *MockAckFrameSource - cryptoStream *MockCryptoStream - sealingManager *MockSealingManager + cryptoStream *MockStreamI + sealingManager *MockSealingManagerLegacy sealer *mocks.MockSealer divNonce []byte ) @@ -49,10 +49,10 @@ var _ = Describe("Packet packer (legacy)", func() { version := versionGQUICFrames mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() - cryptoStream = NewMockCryptoStream(mockCtrl) + cryptoStream = NewMockStreamI(mockCtrl) framer = NewMockFrameSource(mockCtrl) ackFramer = NewMockAckFrameSource(mockCtrl) - sealingManager = NewMockSealingManager(mockCtrl) + sealingManager = NewMockSealingManagerLegacy(mockCtrl) sealer = mocks.NewMockSealer(mockCtrl) sealer.EXPECT().Overhead().Return(9).AnyTimes() sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { @@ -544,7 +544,7 @@ var _ = Describe("Packet packer (legacy)", func() { It("sends unencrypted stream data on the crypto stream", func() { sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionUnencrypted, sealer) f := &wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), + StreamID: 1, Data: []byte("foobar"), } cryptoStream.EXPECT().hasData().Return(true) @@ -558,7 +558,7 @@ var _ = Describe("Packet packer (legacy)", func() { It("sends encrypted stream data on the crypto stream", func() { sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionSecure, sealer) f := &wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), + StreamID: 1, Data: []byte("foobar"), } cryptoStream.EXPECT().hasData().Return(true) diff --git a/packet_packer_test.go b/packet_packer_test.go index ed015bae..b70a77a5 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -18,13 +18,14 @@ import ( var _ = Describe("Packet packer", func() { const maxPacketSize protocol.ByteCount = 1357 var ( - packer *packetPacker - framer *MockFrameSource - ackFramer *MockAckFrameSource - cryptoStream *MockCryptoStream - sealingManager *MockSealingManager - sealer *mocks.MockSealer - token []byte + packer *packetPacker + framer *MockFrameSource + ackFramer *MockAckFrameSource + initialStream *MockCryptoStream + handshakeStream *MockCryptoStream + sealingManager *MockSealingManagerLegacy + sealer *mocks.MockSealer + token []byte ) checkPayloadLen := func(data []byte) { @@ -57,10 +58,11 @@ var _ = Describe("Packet packer", func() { version := versionIETFFrames mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() - cryptoStream = NewMockCryptoStream(mockCtrl) + initialStream = NewMockCryptoStream(mockCtrl) + handshakeStream = NewMockCryptoStream(mockCtrl) framer = NewMockFrameSource(mockCtrl) ackFramer = NewMockAckFrameSource(mockCtrl) - sealingManager = NewMockSealingManager(mockCtrl) + sealingManager = NewMockSealingManagerLegacy(mockCtrl) sealer = mocks.NewMockSealer(mockCtrl) sealer.EXPECT().Overhead().Return(7).AnyTimes() sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { @@ -72,11 +74,12 @@ var _ = Describe("Packet packer", func() { packer = newPacketPacker( protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + initialStream, + handshakeStream, 1, func(protocol.PacketNumber) protocol.PacketNumberLen { return protocol.PacketNumberLen2 }, &net.TCPAddr{}, token, // token - cryptoStream, sealingManager, framer, ackFramer, @@ -105,53 +108,9 @@ var _ = Describe("Packet packer", func() { }) }) - It("returns nil when no packet is queued", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - ackFramer.EXPECT().GetAckFrame() - cryptoStream.EXPECT().hasData() - framer.EXPECT().AppendControlFrames(nil, gomock.Any()) - framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) - p, err := packer.PackPacket() - Expect(p).To(BeNil()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("packs single packets", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - expectAppendControlFrames() - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - } - expectAppendStreamFrames(f) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - b := &bytes.Buffer{} - f.Write(b, packer.version) - Expect(p.frames).To(Equal([]wire.Frame{f})) - Expect(p.raw).To(ContainSubstring(b.String())) - }) - - It("stores the encryption level a packet was sealed with", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - expectAppendControlFrames() - expectAppendStreamFrames(&wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - }) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - }) - Context("generating a packet header", func() { - It("uses the Long Header format for non-forward-secure packets", func() { - h := packer.getHeader(protocol.EncryptionSecure) + It("uses the Long Header format", func() { + h := packer.getHeader(protocol.EncryptionHandshake) Expect(h.IsLongHeader).To(BeTrue()) Expect(h.Version).To(Equal(packer.version)) }) @@ -172,430 +131,575 @@ var _ = Describe("Packet packer", func() { dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} packer.ChangeDestConnectionID(dest1) - h := packer.getHeader(protocol.EncryptionUnencrypted) + h := packer.getHeader(protocol.EncryptionInitial) Expect(h.SrcConnectionID).To(Equal(srcConnID)) Expect(h.DestConnectionID).To(Equal(dest1)) packer.ChangeDestConnectionID(dest2) - h = packer.getHeader(protocol.EncryptionUnencrypted) + h = packer.getHeader(protocol.EncryptionInitial) Expect(h.SrcConnectionID).To(Equal(srcConnID)) Expect(h.DestConnectionID).To(Equal(dest2)) }) - It("uses the Short Header format for forward-secure packets", func() { - h := packer.getHeader(protocol.EncryptionForwardSecure) + It("uses the Short Header format for 1-RTT packets", func() { + h := packer.getHeader(protocol.Encryption1RTT) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.PacketNumberLen).To(BeNumerically(">", 0)) }) }) - It("sets the payload length for packets containing crypto data", func() { - f := &wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), - Offset: 0x1337, - Data: []byte("foobar"), - } - sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionUnencrypted, sealer) - cryptoStream.EXPECT().hasData().Return(true) - cryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - checkPayloadLen(p.raw) - }) + Context("packing normal packets", func() { + BeforeEach(func() { + initialStream.EXPECT().HasData().AnyTimes() + handshakeStream.EXPECT().HasData().AnyTimes() + }) - It("packs a CONNECTION_CLOSE", func() { - ccf := wire.ConnectionCloseFrame{ - ErrorCode: 0x1337, - ReasonPhrase: "foobar", - } - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - p, err := packer.PackConnectionClose(&ccf) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(Equal(&ccf)) - }) + It("returns nil when no packet is queued", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + framer.EXPECT().AppendControlFrames(nil, gomock.Any()) + framer.EXPECT().AppendStreamFrames(nil, gomock.Any()) + p, err := packer.PackPacket() + Expect(p).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) + }) - It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() { - // expect no framer.PopStreamFrames - ccf := &wire.ConnectionCloseFrame{ - ErrorCode: 0x1337, - ReasonPhrase: "foobar", - } - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - p, err := packer.PackConnectionClose(ccf) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{ccf})) - }) + It("packs single packets", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() + f := &wire.StreamFrame{ + StreamID: 5, + Data: []byte{0xde, 0xca, 0xfb, 0xad}, + } + expectAppendStreamFrames(f) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + b := &bytes.Buffer{} + f.Write(b, packer.version) + Expect(p.frames).To(Equal([]wire.Frame{f})) + Expect(p.raw).To(ContainSubstring(b.String())) + }) - It("packs control frames", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - frames := []wire.Frame{&wire.RstStreamFrame{}, &wire.MaxDataFrame{}} - expectAppendControlFrames(frames...) - expectAppendStreamFrames() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal(frames)) - Expect(p.raw).NotTo(BeEmpty()) - }) + It("stores the encryption level a packet was sealed with", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() + expectAppendStreamFrames(&wire.StreamFrame{ + StreamID: 5, + Data: []byte("foobar"), + }) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + }) - It("increases the packet number", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer).Times(2) - cryptoStream.EXPECT().hasData().Times(2) - ackFramer.EXPECT().GetAckFrame().Times(2) - expectAppendControlFrames() - expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")}) - expectAppendControlFrames() - expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("raboof")}) - p1, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p1).ToNot(BeNil()) - p2, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p2).ToNot(BeNil()) - Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber)) - }) + It("packs a single ACK", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} + ackFramer.EXPECT().GetAckFrame().Return(ack) + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + expectAppendControlFrames() + expectAppendStreamFrames() + p, err := packer.PackPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.frames[0]).To(Equal(ack)) + }) - It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { - cryptoStream.EXPECT().hasData() - packer.hasSentPacket = false - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - }) + It("packs a CONNECTION_CLOSE", func() { + ccf := wire.ConnectionCloseFrame{ + ErrorCode: 0x1337, + ReasonPhrase: "foobar", + } + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + p, err := packer.PackConnectionClose(&ccf) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + Expect(p.frames[0]).To(Equal(&ccf)) + }) - It("accounts for the space consumed by control frames", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - var maxSize protocol.ByteCount - gomock.InOrder( - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - maxSize = maxLen - return fs, 444 - }), - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { - Expect(maxLen).To(Equal(maxSize - 444 + 1 /* data length of the STREAM frame */)) - return nil - }), - ) - _, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - }) + It("doesn't send any other frames when sending a CONNECTION_CLOSE", func() { + // expect no framer.PopStreamFrames + ccf := &wire.ConnectionCloseFrame{ + ErrorCode: 0x1337, + ReasonPhrase: "foobar", + } + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + p, err := packer.PackConnectionClose(ccf) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{ccf})) + }) - It("only increases the packet number when there is an actual packet to send", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer).Times(2) - ackFramer.EXPECT().GetAckFrame().Times(2) - cryptoStream.EXPECT().hasData().Times(2) - expectAppendStreamFrames() - expectAppendControlFrames() - packer.packetNumberGenerator.nextToSkip = 1000 - p, err := packer.PackPacket() - Expect(p).To(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) - expectAppendControlFrames() - expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")}) - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.header.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2))) - }) + It("packs control frames", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + frames := []wire.Frame{&wire.RstStreamFrame{}, &wire.MaxDataFrame{}} + expectAppendControlFrames(frames...) + expectAppendStreamFrames() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(Equal(frames)) + Expect(p.raw).NotTo(BeEmpty()) + }) - Context("making ACK packets retransmittable", func() { - sendMaxNumNonRetransmittableAcks := func() { - cryptoStream.EXPECT().hasData().Times(protocol.MaxNonRetransmittableAcks) - for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) + It("increases the packet number", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2) + ackFramer.EXPECT().GetAckFrame().Times(2) + expectAppendControlFrames() + expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")}) + expectAppendControlFrames() + expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("raboof")}) + p1, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p1).ToNot(BeNil()) + p2, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p2).ToNot(BeNil()) + Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber)) + }) + + It("accounts for the space consumed by control frames", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + var maxSize protocol.ByteCount + gomock.InOrder( + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + maxSize = maxLen + return fs, 444 + }), + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { + Expect(maxLen).To(Equal(maxSize - 444 + 1 /* data length of the STREAM frame */)) + return nil + }), + ) + _, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + }) + + It("only increases the packet number when there is an actual packet to send", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2) + ackFramer.EXPECT().GetAckFrame().Times(2) + expectAppendStreamFrames() + expectAppendControlFrames() + packer.packetNumberGenerator.nextToSkip = 1000 + p, err := packer.PackPacket() + Expect(p).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) + expectAppendControlFrames() + expectAppendStreamFrames(&wire.StreamFrame{Data: []byte("foobar")}) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(2))) + }) + + Context("packing ACK packets", func() { + It("doesn't pack a packet if there's no ACK to send", func() { + ackFramer.EXPECT().GetAckFrame() + p, err := packer.MaybePackAckPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + + It("packs ACK packets", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + ackFramer.EXPECT().GetAckFrame().Return(ack) + p, err := packer.MaybePackAckPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{ack})) + }) + }) + + Context("making ACK packets retransmittable", func() { + sendMaxNumNonRetransmittableAcks := func() { + for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + expectAppendControlFrames() + expectAppendStreamFrames() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + } + } + + It("adds a PING frame when it's supposed to send a retransmittable packet", func() { + sendMaxNumNonRetransmittableAcks() + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + // make sure the next packet doesn't contain another PING + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + expectAppendControlFrames() + expectAppendStreamFrames() + p, err = packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - } - } + }) - It("adds a PING frame when it's supposed to send a retransmittable packet", func() { - sendMaxNumNonRetransmittableAcks() - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) - expectAppendControlFrames() - expectAppendStreamFrames() - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) - // make sure the next packet doesn't contain another PING - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) - expectAppendControlFrames() - expectAppendStreamFrames() - p, err = packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) + It("waits until there's something to send before adding a PING frame", func() { + sendMaxNumNonRetransmittableAcks() + // nothing to send + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + expectAppendControlFrames() + expectAppendStreamFrames() + ackFramer.EXPECT().GetAckFrame() + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + // now add some frame to send + expectAppendControlFrames() + expectAppendStreamFrames() + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + p, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + }) + + It("doesn't send a PING if it already sent another retransmittable frame", func() { + sendMaxNumNonRetransmittableAcks() + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + expectAppendStreamFrames() + expectAppendControlFrames(&wire.MaxDataFrame{}) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) + }) }) - It("waits until there's something to send before adding a PING frame", func() { - sendMaxNumNonRetransmittableAcks() - // nothing to send - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - expectAppendControlFrames() - expectAppendStreamFrames() - ackFramer.EXPECT().GetAckFrame() - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - // now add some frame to send - expectAppendControlFrames() - expectAppendStreamFrames() - cryptoStream.EXPECT().hasData() - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - ackFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) - p, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames).To(ContainElement(&wire.PingFrame{})) + Context("STREAM frame handling", func() { + It("does not split a STREAM frame with maximum size", func() { + ackFramer.EXPECT().GetAckFrame() + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + expectAppendControlFrames() + sf := &wire.StreamFrame{ + Offset: 1, + StreamID: 5, + DataLenPresent: true, + } + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + sf.Data = bytes.Repeat([]byte{'f'}, int(maxSize-sf.Length(packer.version))) + return []wire.Frame{sf}, sf.Length(packer.version) + }) + p, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(1)) + Expect(p.raw).To(HaveLen(int(maxPacketSize))) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(HaveLen(len(sf.Data))) + Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + }) + + It("packs multiple small STREAM frames into single packet", func() { + f1 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 1"), + DataLenPresent: true, + } + f2 := &wire.StreamFrame{ + StreamID: 5, + Data: []byte("frame 2"), + DataLenPresent: true, + } + f3 := &wire.StreamFrame{ + StreamID: 3, + Data: []byte("frame 3"), + DataLenPresent: true, + } + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer) + ackFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() + expectAppendStreamFrames(f1, f2, f3) + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.frames).To(HaveLen(3)) + Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) + Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) + Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) + Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + }) + + It("doesn't send unencrypted stream data on a data stream", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionInitial, sealer) + ackFramer.EXPECT().GetAckFrame() + expectAppendControlFrames() + // don't expect a call to framer.PopStreamFrames + p, err := packer.PackPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p).To(BeNil()) + }) }) - It("doesn't send a PING if it already sent another retransmittable frame", func() { - sendMaxNumNonRetransmittableAcks() - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - expectAppendStreamFrames() - expectAppendControlFrames(&wire.MaxDataFrame{}) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) + Context("retransmissions", func() { + It("retransmits a small packet", func() { + packer.packetNumberGenerator.next = 10 + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil) + frames := []wire.Frame{ + &wire.MaxDataFrame{ByteOffset: 0x1234}, + &wire.StreamFrame{StreamID: 42, Data: []byte("foobar")}, + } + packets, err := packer.PackRetransmission(&ackhandler.Packet{ + EncryptionLevel: protocol.Encryption1RTT, + Frames: frames, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(packets).To(HaveLen(1)) + p := packets[0] + Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(p.frames).To(Equal(frames)) + }) + + It("packs two packets for retransmission if the original packet contained many control frames", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil) + var frames []wire.Frame + var totalLen protocol.ByteCount + // pack a bunch of control frames, such that the packet is way bigger than a single packet + for i := 0; totalLen < maxPacketSize*3/2; i++ { + f := &wire.MaxStreamDataFrame{ + StreamID: protocol.StreamID(i), + ByteOffset: protocol.ByteCount(i), + } + frames = append(frames, f) + totalLen += f.Length(packer.version) + } + packer.packetNumberGenerator.next = 10 + packets, err := packer.PackRetransmission(&ackhandler.Packet{ + EncryptionLevel: protocol.Encryption1RTT, + Frames: frames, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(packets).To(HaveLen(2)) + Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames))) + Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):])) + // check that the first packet was filled up as far as possible: + // if the first frame (after the STOP_WAITING) was packed into the first packet, it would have overflown the MaxPacketSize + Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize)) + }) + + It("splits a STREAM frame that doesn't fit", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil) + packets, err := packer.PackRetransmission(&ackhandler.Packet{ + EncryptionLevel: protocol.Encryption1RTT, + Frames: []wire.Frame{&wire.StreamFrame{ + StreamID: 42, + Offset: 1337, + Data: bytes.Repeat([]byte{'a'}, int(maxPacketSize)*3/2), + }}, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(packets).To(HaveLen(2)) + Expect(packets[0].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(packets[1].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + sf1 := packets[0].frames[0].(*wire.StreamFrame) + sf2 := packets[1].frames[0].(*wire.StreamFrame) + Expect(sf1.StreamID).To(Equal(protocol.StreamID(42))) + Expect(sf1.Offset).To(Equal(protocol.ByteCount(1337))) + Expect(sf1.DataLenPresent).To(BeFalse()) + Expect(sf2.StreamID).To(Equal(protocol.StreamID(42))) + Expect(sf2.Offset).To(Equal(protocol.ByteCount(1337) + sf1.DataLen())) + Expect(sf2.DataLenPresent).To(BeFalse()) + Expect(sf1.DataLen() + sf2.DataLen()).To(Equal(maxPacketSize * 3 / 2)) + Expect(packets[0].raw).To(HaveLen(int(maxPacketSize))) + }) + + It("splits STREAM frames, if necessary", func() { + for i := 0; i < 100; i++ { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil).MaxTimes(2) + sf1 := &wire.StreamFrame{ + StreamID: 42, + Offset: 1337, + Data: bytes.Repeat([]byte{'a'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))), + } + sf2 := &wire.StreamFrame{ + StreamID: 2, + Offset: 42, + Data: bytes.Repeat([]byte{'b'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))), + } + expectedDataLen := sf1.DataLen() + sf2.DataLen() + frames := []wire.Frame{sf1, sf2} + packets, err := packer.PackRetransmission(&ackhandler.Packet{ + EncryptionLevel: protocol.Encryption1RTT, + Frames: frames, + }) + Expect(err).ToNot(HaveOccurred()) + + if len(packets) > 1 { + Expect(packets[0].raw).To(HaveLen(int(maxPacketSize))) + } + + var dataLen protocol.ByteCount + for _, p := range packets { + for _, f := range p.frames { + dataLen += f.(*wire.StreamFrame).DataLen() + } + } + Expect(dataLen).To(Equal(expectedDataLen)) + } + }) + + It("packs two packets for retransmission if the original packet contained many STREAM frames", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil) + var frames []wire.Frame + var totalLen protocol.ByteCount + // pack a bunch of control frames, such that the packet is way bigger than a single packet + for i := 0; totalLen < maxPacketSize*3/2; i++ { + f := &wire.StreamFrame{ + StreamID: protocol.StreamID(i), + Data: []byte("foobar"), + DataLenPresent: true, + } + frames = append(frames, f) + totalLen += f.Length(packer.version) + } + packets, err := packer.PackRetransmission(&ackhandler.Packet{ + EncryptionLevel: protocol.Encryption1RTT, + Frames: frames, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(packets).To(HaveLen(2)) + Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames))) // all frames + Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):])) + // check that the first packet was filled up as far as possible: + // if the first frame was packed into the first packet, it would have overflown the MaxPacketSize + Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize-protocol.MinStreamFrameSize)) + }) + + It("correctly sets the DataLenPresent on STREAM frames", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.Encryption1RTT).Return(sealer, nil) + frames := []wire.Frame{ + &wire.StreamFrame{StreamID: 4, Data: []byte("foobar"), DataLenPresent: true}, + &wire.StreamFrame{StreamID: 5, Data: []byte("barfoo")}, + } + packets, err := packer.PackRetransmission(&ackhandler.Packet{ + EncryptionLevel: protocol.Encryption1RTT, + Frames: frames, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(packets).To(HaveLen(1)) + p := packets[0] + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(p.frames[1]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + sf1 := p.frames[0].(*wire.StreamFrame) + sf2 := p.frames[1].(*wire.StreamFrame) + Expect(sf1.StreamID).To(Equal(protocol.StreamID(4))) + Expect(sf1.DataLenPresent).To(BeTrue()) + Expect(sf2.StreamID).To(Equal(protocol.StreamID(5))) + Expect(sf2.DataLenPresent).To(BeFalse()) + }) + }) + + Context("max packet size", func() { + It("sets the maximum packet size", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2) + ackFramer.EXPECT().GetAckFrame().Times(2) + var initialMaxPacketSize protocol.ByteCount + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + initialMaxPacketSize = maxLen + return nil, 0 + }) + expectAppendStreamFrames() + _, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // now reduce the maxPacketSize + packer.HandleTransportParameters(&handshake.TransportParameters{ + MaxPacketSize: maxPacketSize - 10, + }) + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(initialMaxPacketSize - 10)) + return nil, 0 + }) + expectAppendStreamFrames() + _, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + }) + + It("doesn't increase the max packet size", func() { + sealingManager.EXPECT().GetSealer().Return(protocol.Encryption1RTT, sealer).Times(2) + ackFramer.EXPECT().GetAckFrame().Times(2) + var initialMaxPacketSize protocol.ByteCount + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + initialMaxPacketSize = maxLen + return nil, 0 + }) + expectAppendStreamFrames() + _, err := packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + // now try to increase the maxPacketSize + packer.HandleTransportParameters(&handshake.TransportParameters{ + MaxPacketSize: maxPacketSize + 10, + }) + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { + Expect(maxLen).To(Equal(initialMaxPacketSize)) + return nil, 0 + }) + expectAppendStreamFrames() + _, err = packer.PackPacket() + Expect(err).ToNot(HaveOccurred()) + }) }) }) - Context("STREAM frame handling", func() { - It("does not split a STREAM frame with maximum size", func() { - ackFramer.EXPECT().GetAckFrame() - cryptoStream.EXPECT().hasData() - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - expectAppendControlFrames() - sf := &wire.StreamFrame{ - Offset: 1, - StreamID: 5, - DataLenPresent: true, + Context("packing crypto packets", func() { + It("sets the payload length", func() { + f := &wire.CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), } - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - sf.Data = bytes.Repeat([]byte{'f'}, int(maxSize-sf.Length(packer.version))) - return []wire.Frame{sf}, sf.Length(packer.version) - }) + initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.raw).To(HaveLen(int(maxPacketSize))) - Expect(p.frames[0].(*wire.StreamFrame).Data).To(HaveLen(len(sf.Data))) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - }) - - It("packs multiple small STREAM frames into single packet", func() { - f1 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("frame 1"), - DataLenPresent: true, - } - f2 := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("frame 2"), - DataLenPresent: true, - } - f3 := &wire.StreamFrame{ - StreamID: 3, - Data: []byte("frame 3"), - DataLenPresent: true, - } - cryptoStream.EXPECT().hasData() - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - ackFramer.EXPECT().GetAckFrame() - expectAppendControlFrames() - expectAppendStreamFrames(f1, f2, f3) - p, err := packer.PackPacket() - Expect(p).ToNot(BeNil()) - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(3)) - Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) - Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) - Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) - Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) - }) - - It("refuses to send unencrypted stream data on a data stream", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionUnencrypted, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - expectAppendControlFrames() - // don't expect a call to framer.PopStreamFrames - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p).To(BeNil()) - }) - - It("sends non forward-secure data as the client", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - expectAppendControlFrames() - f := &wire.StreamFrame{ - StreamID: 5, - Data: []byte("foobar"), - } - expectAppendStreamFrames(f) - packer.perspective = protocol.PerspectiveClient - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(p.frames).To(Equal([]wire.Frame{f})) - }) - - It("does not send non-forward-secure data as the server", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionSecure, sealer) - cryptoStream.EXPECT().hasData() - ackFramer.EXPECT().GetAckFrame() - expectAppendControlFrames() - // don't expect a call to framer.PopStreamFrames - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) + checkPayloadLen(p.raw) }) It("packs a maximum size crypto packet", func() { - var f *wire.StreamFrame + var f *wire.CryptoFrame packer.version = versionIETFFrames - sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionUnencrypted, sealer) - cryptoStream.EXPECT().hasData().Return(true) - cryptoStream.EXPECT().popStreamFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) (*wire.StreamFrame, bool) { - f = &wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), - Offset: 0x1337, - } - f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version))) - return f, false + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionHandshake).Return(sealer, nil) + initialStream.EXPECT().HasData() + handshakeStream.EXPECT().HasData().Return(true) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { + f = &wire.CryptoFrame{Offset: 0x1337} + f.Data = bytes.Repeat([]byte{'f'}, int(size-f.Length(packer.version)-1)) + Expect(f.Length(packer.version)).To(Equal(size)) + return f }) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) - expectedPacketLen := packer.maxPacketSize - protocol.NonForwardSecurePacketSizeReduction + expectedPacketLen := packer.maxPacketSize Expect(p.raw).To(HaveLen(int(expectedPacketLen))) Expect(p.header.IsLongHeader).To(BeTrue()) checkPayloadLen(p.raw) }) - It("sends unencrypted stream data on the crypto stream", func() { - f := &wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), - Data: []byte("foobar"), - } - cryptoStream.EXPECT().hasData().Return(true) - cryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) - sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionUnencrypted, sealer) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{f})) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - }) - - It("sends encrypted stream data on the crypto stream", func() { - f := &wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), - Data: []byte("foobar"), - } - cryptoStream.EXPECT().hasData().Return(true) - cryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) - sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionSecure, sealer) - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{f})) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - }) - - It("does not pack STREAM frames if not allowed", func() { - cryptoStream.EXPECT().hasData() - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionUnencrypted, sealer) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 10, Smallest: 1}}} - ackFramer.EXPECT().GetAckFrame().Return(ack) - expectAppendControlFrames() - // don't expect a call to framer.PopStreamFrames - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{ack})) - }) - }) - - It("packs a single ACK", func() { - cryptoStream.EXPECT().hasData() - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} - ackFramer.EXPECT().GetAckFrame().Return(ack) - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - expectAppendControlFrames() - expectAppendStreamFrames() - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) - Expect(p.frames[0]).To(Equal(ack)) - }) - - Context("retransmitting of handshake packets", func() { - sf := &wire.StreamFrame{ - StreamID: 1, - Data: []byte("foobar"), - } - - It("packs a retransmission with the right encryption level", func() { - sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted).Return(sealer, nil) - packet := &ackhandler.Packet{ - PacketType: protocol.PacketTypeHandshake, - EncryptionLevel: protocol.EncryptionUnencrypted, - Frames: []wire.Frame{sf}, - } - p, err := packer.PackRetransmission(packet) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(HaveLen(1)) - Expect(p[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) - Expect(p[0].frames).To(Equal([]wire.Frame{sf})) - Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - }) - - // this should never happen, since non forward-secure packets are limited to a size smaller than MaxPacketSize, such that it is always possible to retransmit them without splitting the StreamFrame - It("refuses to send a packet larger than MaxPacketSize", func() { - sealingManager.EXPECT().GetSealerWithEncryptionLevel(gomock.Any()).Return(sealer, nil) - packet := &ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionSecure, - Frames: []wire.Frame{ - &wire.StreamFrame{ - StreamID: 1, - Data: bytes.Repeat([]byte{'f'}, int(maxPacketSize)), - }, - }, - } - _, err := packer.PackRetransmission(packet) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("PacketPacker BUG: packet too large")) - }) - It("pads Initial packets to the required minimum packet size", func() { - f := &wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), - Data: []byte("foobar"), + f := &wire.CryptoFrame{ + Data: []byte("foobar"), } - sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionUnencrypted, sealer) - cryptoStream.EXPECT().hasData().Return(true) - cryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) + initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient @@ -604,18 +708,16 @@ var _ = Describe("Packet packer", func() { Expect(packet.header.Token).To(Equal(token)) Expect(packet.raw).To(HaveLen(protocol.MinInitialPacketSize)) Expect(packet.frames).To(HaveLen(1)) - sf := packet.frames[0].(*wire.StreamFrame) - Expect(sf.Data).To(Equal([]byte("foobar"))) - Expect(sf.DataLenPresent).To(BeTrue()) + cf := packet.frames[0].(*wire.CryptoFrame) + Expect(cf.Data).To(Equal([]byte("foobar"))) }) - It("set the correct payload length for an Initial packet", func() { - sealingManager.EXPECT().GetSealerForCryptoStream().Return(protocol.EncryptionUnencrypted, sealer) - cryptoStream.EXPECT().hasData().Return(true) - cryptoStream.EXPECT().popStreamFrame(gomock.Any()).Return(&wire.StreamFrame{ - StreamID: packer.version.CryptoStreamID(), - Data: []byte("foobar"), - }, false) + It("sets the correct payload length for an Initial packet", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) + initialStream.EXPECT().HasData().Return(true) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{ + Data: []byte("foobar"), + }) packer.hasSentPacket = false packer.perspective = protocol.PerspectiveClient packet, err := packer.PackPacket() @@ -623,251 +725,58 @@ var _ = Describe("Packet packer", func() { checkPayloadLen(packet.raw) }) - It("packs a retransmission for an Initial packet", func() { - sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted).Return(sealer, nil) - packer.version = versionIETFFrames - packer.perspective = protocol.PerspectiveClient - packet := &ackhandler.Packet{ - PacketType: protocol.PacketTypeInitial, - EncryptionLevel: protocol.EncryptionUnencrypted, - Frames: []wire.Frame{sf}, - } - p, err := packer.PackRetransmission(packet) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(HaveLen(1)) - Expect(p[0].frames).To(Equal([]wire.Frame{sf})) - Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p[0].header.Token).To(Equal(token)) - }) - }) + Context("retransmitions", func() { + sf := &wire.StreamFrame{Data: []byte("foobar")} - Context("retransmission of forward-secure packets", func() { - It("retransmits a small packet", func() { - packer.packetNumberGenerator.next = 10 - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - frames := []wire.Frame{ - &wire.MaxDataFrame{ByteOffset: 0x1234}, - &wire.StreamFrame{StreamID: 42, Data: []byte("foobar")}, - } - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - Expect(p.frames).To(Equal(frames)) - }) - - It("packs two packets for retransmission if the original packet contained many control frames", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - var frames []wire.Frame - var totalLen protocol.ByteCount - // pack a bunch of control frames, such that the packet is way bigger than a single packet - for i := 0; totalLen < maxPacketSize*3/2; i++ { - f := &wire.MaxStreamDataFrame{ - StreamID: protocol.StreamID(i), - ByteOffset: protocol.ByteCount(i), + It("packs a retransmission with the right encryption level", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) + packet := &ackhandler.Packet{ + PacketType: protocol.PacketTypeHandshake, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []wire.Frame{sf}, } - frames = append(frames, f) - totalLen += f.Length(packer.version) - } - packer.packetNumberGenerator.next = 10 - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(2)) - Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames))) - Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):])) - // check that the first packet was filled up as far as possible: - // if the first frame (after the STOP_WAITING) was packed into the first packet, it would have overflown the MaxPacketSize - Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize)) - }) - - It("splits a STREAM frame that doesn't fit", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: []wire.Frame{&wire.StreamFrame{ - StreamID: 42, - Offset: 1337, - Data: bytes.Repeat([]byte{'a'}, int(maxPacketSize)*3/2), - }}, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(2)) - Expect(packets[0].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(packets[1].frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - sf1 := packets[0].frames[0].(*wire.StreamFrame) - sf2 := packets[1].frames[0].(*wire.StreamFrame) - Expect(sf1.StreamID).To(Equal(protocol.StreamID(42))) - Expect(sf1.Offset).To(Equal(protocol.ByteCount(1337))) - Expect(sf1.DataLenPresent).To(BeFalse()) - Expect(sf2.StreamID).To(Equal(protocol.StreamID(42))) - Expect(sf2.Offset).To(Equal(protocol.ByteCount(1337) + sf1.DataLen())) - Expect(sf2.DataLenPresent).To(BeFalse()) - Expect(sf1.DataLen() + sf2.DataLen()).To(Equal(maxPacketSize * 3 / 2)) - Expect(packets[0].raw).To(HaveLen(int(maxPacketSize))) - }) - - It("splits STREAM frames, if necessary", func() { - for i := 0; i < 100; i++ { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer).MaxTimes(2) - sf1 := &wire.StreamFrame{ - StreamID: 42, - Offset: 1337, - Data: bytes.Repeat([]byte{'a'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))), - } - sf2 := &wire.StreamFrame{ - StreamID: 2, - Offset: 42, - Data: bytes.Repeat([]byte{'b'}, 1+int(rand.Int31n(int32(maxPacketSize*4/5)))), - } - expectedDataLen := sf1.DataLen() + sf2.DataLen() - frames := []wire.Frame{sf1, sf2} - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: frames, - }) + p, err := packer.PackRetransmission(packet) Expect(err).ToNot(HaveOccurred()) + Expect(p).To(HaveLen(1)) + Expect(p[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(p[0].frames).To(Equal([]wire.Frame{sf})) + Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionInitial)) + }) - if len(packets) > 1 { - Expect(packets[0].raw).To(HaveLen(int(maxPacketSize))) + // this should never happen, since non forward-secure packets are limited to a size smaller than MaxPacketSize, such that it is always possible to retransmit them without splitting the StreamFrame + It("refuses to send a packet larger than MaxPacketSize", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(gomock.Any()).Return(sealer, nil) + packet := &ackhandler.Packet{ + EncryptionLevel: protocol.EncryptionHandshake, + Frames: []wire.Frame{ + &wire.StreamFrame{ + StreamID: 1, + Data: bytes.Repeat([]byte{'f'}, int(maxPacketSize)), + }, + }, } + _, err := packer.PackRetransmission(packet) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("PacketPacker BUG: packet too large")) + }) - var dataLen protocol.ByteCount - for _, p := range packets { - for _, f := range p.frames { - dataLen += f.(*wire.StreamFrame).DataLen() - } + It("packs a retransmission for an Initial packet", func() { + sealingManager.EXPECT().GetSealerWithEncryptionLevel(protocol.EncryptionInitial).Return(sealer, nil) + packer.version = versionIETFFrames + packer.perspective = protocol.PerspectiveClient + packet := &ackhandler.Packet{ + PacketType: protocol.PacketTypeInitial, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []wire.Frame{sf}, } - Expect(dataLen).To(Equal(expectedDataLen)) - } - }) - - It("packs two packets for retransmission if the original packet contained many STREAM frames", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - var frames []wire.Frame - var totalLen protocol.ByteCount - // pack a bunch of control frames, such that the packet is way bigger than a single packet - for i := 0; totalLen < maxPacketSize*3/2; i++ { - f := &wire.StreamFrame{ - StreamID: protocol.StreamID(i), - Data: []byte("foobar"), - DataLenPresent: true, - } - frames = append(frames, f) - totalLen += f.Length(packer.version) - } - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: frames, + p, err := packer.PackRetransmission(packet) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(HaveLen(1)) + Expect(p[0].frames).To(Equal([]wire.Frame{sf})) + Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(p[0].header.Token).To(Equal(token)) }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(2)) - Expect(len(packets[0].frames) + len(packets[1].frames)).To(Equal(len(frames))) // all frames - Expect(packets[1].frames).To(Equal(frames[len(packets[0].frames):])) - // check that the first packet was filled up as far as possible: - // if the first frame was packed into the first packet, it would have overflown the MaxPacketSize - Expect(len(packets[0].raw) + int(packets[1].frames[1].Length(packer.version))).To(BeNumerically(">", maxPacketSize-protocol.MinStreamFrameSize)) - }) - - It("correctly sets the DataLenPresent on STREAM frames", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - frames := []wire.Frame{ - &wire.StreamFrame{StreamID: 4, Data: []byte("foobar"), DataLenPresent: true}, - &wire.StreamFrame{StreamID: 5, Data: []byte("barfoo")}, - } - packets, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: frames, - }) - Expect(err).ToNot(HaveOccurred()) - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(p.frames[1]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - sf1 := p.frames[0].(*wire.StreamFrame) - sf2 := p.frames[1].(*wire.StreamFrame) - Expect(sf1.StreamID).To(Equal(protocol.StreamID(4))) - Expect(sf1.DataLenPresent).To(BeTrue()) - Expect(sf2.StreamID).To(Equal(protocol.StreamID(5))) - Expect(sf2.DataLenPresent).To(BeFalse()) - }) - }) - - Context("packing ACK packets", func() { - It("doesn't pack a packet if there's no ACK to send", func() { - ackFramer.EXPECT().GetAckFrame() - p, err := packer.MaybePackAckPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(BeNil()) - }) - - It("packs ACK packets", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} - ackFramer.EXPECT().GetAckFrame().Return(ack) - p, err := packer.MaybePackAckPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(Equal([]wire.Frame{ack})) - }) - }) - - Context("max packet size", func() { - It("sets the maximum packet size", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer).Times(2) - ackFramer.EXPECT().GetAckFrame().Times(2) - cryptoStream.EXPECT().hasData().AnyTimes() - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - // now reduce the maxPacketSize - packer.HandleTransportParameters(&handshake.TransportParameters{ - MaxPacketSize: maxPacketSize - 10, - }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize - 10)) - return nil, 0 - }) - expectAppendStreamFrames() - _, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't increase the max packet size", func() { - sealingManager.EXPECT().GetSealer().Return(protocol.EncryptionForwardSecure, sealer).Times(2) - ackFramer.EXPECT().GetAckFrame().Times(2) - cryptoStream.EXPECT().hasData().AnyTimes() - var initialMaxPacketSize protocol.ByteCount - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - initialMaxPacketSize = maxLen - return nil, 0 - }) - expectAppendStreamFrames() - _, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - // now try to increase the maxPacketSize - packer.HandleTransportParameters(&handshake.TransportParameters{ - MaxPacketSize: maxPacketSize + 10, - }) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any()).Do(func(_ []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { - Expect(maxLen).To(Equal(initialMaxPacketSize)) - return nil, 0 - }) - expectAppendStreamFrames() - _, err = packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) }) }) }) diff --git a/packet_unpacker.go b/packet_unpacker.go index 99497907..b7638f45 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/wire" @@ -18,8 +19,9 @@ type gQUICAEAD interface { } type quicAEAD interface { - OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) - Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) + OpenInitial(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) + OpenHandshake(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) + Open1RTT(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) } type packetUnpackerBase struct { @@ -103,12 +105,19 @@ func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []by var decrypted []byte var encryptionLevel protocol.EncryptionLevel var err error - if hdr.IsLongHeader { + switch hdr.Type { + case protocol.PacketTypeInitial: + decrypted, err = u.aead.OpenInitial(buf, data, hdr.PacketNumber, headerBinary) + encryptionLevel = protocol.EncryptionInitial + case protocol.PacketTypeHandshake: decrypted, err = u.aead.OpenHandshake(buf, data, hdr.PacketNumber, headerBinary) - encryptionLevel = protocol.EncryptionUnencrypted - } else { + encryptionLevel = protocol.EncryptionHandshake + default: + if hdr.IsLongHeader { + return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) + } decrypted, err = u.aead.Open1RTT(buf, data, hdr.PacketNumber, headerBinary) - encryptionLevel = protocol.EncryptionForwardSecure + encryptionLevel = protocol.Encryption1RTT } if err != nil { // Wrap err in quicError so that public reset is sent by session diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 1ed5b840..2c6a00ca 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -78,12 +78,22 @@ var _ = Describe("Packet Unpacker (for IETF QUIC)", func() { Expect(err).To(MatchError(qerr.MissingPayload)) }) - It("opens handshake packets", func() { + It("opens Initial packets", func() { hdr.IsLongHeader = true + hdr.Type = protocol.PacketTypeInitial + aead.EXPECT().OpenInitial(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil) + packet, err := unpacker.Unpack(hdr.Raw, hdr, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) + }) + + It("opens Handshake packets", func() { + hdr.IsLongHeader = true + hdr.Type = protocol.PacketTypeHandshake aead.EXPECT().OpenHandshake(gomock.Any(), gomock.Any(), hdr.PacketNumber, hdr.Raw).Return([]byte{0}, nil) packet, err := unpacker.Unpack(hdr.Raw, hdr, nil) Expect(err).ToNot(HaveOccurred()) - Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) + Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionHandshake)) }) It("unpacks the frames", func() { diff --git a/quic_suite_test.go b/quic_suite_test.go index fefe277f..2524a09b 100644 --- a/quic_suite_test.go +++ b/quic_suite_test.go @@ -22,9 +22,7 @@ const ( var mockCtrl *gomock.Controller var _ = BeforeSuite(func() { - Expect(versionGQUICFrames.CryptoStreamID()).To(Equal(protocol.StreamID(1))) Expect(versionGQUICFrames.UsesIETFFrameFormat()).To(BeFalse()) - Expect(versionIETFFrames.CryptoStreamID()).To(Equal(protocol.StreamID(0))) Expect(versionIETFFrames.UsesIETFFrameFormat()).To(BeTrue()) }) diff --git a/receive_stream.go b/receive_stream.go index f16174f3..e18f064a 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -164,7 +164,7 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err s.flowController.AddBytesRead(protocol.ByteCount(m)) } // increase the flow control window, if necessary - if s.streamID != s.version.CryptoStreamID() { + if !s.version.IsCryptoStream(s.streamID) { s.flowController.MaybeQueueWindowUpdate() } diff --git a/send_stream.go b/send_stream.go index 5ffd6efc..5fc04b4d 100644 --- a/send_stream.go +++ b/send_stream.go @@ -195,7 +195,7 @@ func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, boo return nil, s.finishedWriting && !s.finSent } - if s.streamID != s.version.CryptoStreamID() { + if !s.version.IsCryptoStream(s.streamID) { maxBytes = utils.MinByteCount(maxBytes, s.flowController.SendWindowSize()) } if maxBytes == 0 { diff --git a/server_session.go b/server_session.go index 51743b3a..8688cd05 100644 --- a/server_session.go +++ b/server_session.go @@ -46,7 +46,7 @@ func (s *serverSession) handlePacketImpl(p *receivedPacket) error { if hdr.IsLongHeader { switch hdr.Type { - case protocol.PacketTypeHandshake, protocol.PacketType0RTT: // 0-RTT accepted for gQUIC 44 + case protocol.PacketTypeInitial, protocol.PacketTypeHandshake, protocol.PacketType0RTT: // 0-RTT accepted for gQUIC 44 // nothing to do here. Packet will be passed to the session. default: // Note that this also drops 0-RTT packets. diff --git a/server_tls.go b/server_tls.go index 437bee2f..794fa005 100644 --- a/server_tls.go +++ b/server_tls.go @@ -6,7 +6,6 @@ import ( "errors" "net" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -21,11 +20,11 @@ type tlsSession struct { type serverTLS struct { conn net.PacketConn config *Config - mintConf *mint.Config + tlsConf *tls.Config params *handshake.TransportParameters cookieGenerator *handshake.CookieGenerator - newSession func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, *mint.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error) + newSession func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, *tls.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error) sessionRunner sessionRunner sessionChan chan<- tlsSession @@ -54,16 +53,12 @@ func newServerTLS( // TODO(#855): generate a real token StatelessResetToken: bytes.Repeat([]byte{42}, 16), } - mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer) - if err != nil { - return nil, nil, err - } sessionChan := make(chan tlsSession) s := &serverTLS{ conn: conn, config: config, - mintConf: mconf, + tlsConf: tlsConf, sessionRunner: runner, sessionChan: sessionChan, cookieGenerator: cookieGenerator, @@ -114,10 +109,6 @@ func (s *serverTLS) handleInitialImpl(p *receivedPacket) (quicSession, protocol. return nil, nil, s.sendRetry(p.remoteAddr, hdr) } - extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, hdr.Version, s.logger) - mconf := s.mintConf.Clone() - mconf.ExtensionHandler = extHandler - connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { return nil, nil, err @@ -131,7 +122,7 @@ func (s *serverTLS) handleInitialImpl(p *receivedPacket) (quicSession, protocol. connID, 1, s.config, - mconf, + s.tlsConf, s.params, s.logger, hdr.Version, diff --git a/server_tls_test.go b/server_tls_test.go index 4a2024cf..7c1387c3 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -2,9 +2,9 @@ package quic import ( "bytes" + "crypto/tls" "net" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" @@ -98,7 +98,19 @@ var _ = Describe("Stateless TLS handling", func() { data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), } run := make(chan struct{}) - server.newSession = func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, *mint.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error) { + server.newSession = func( + connection, + sessionRunner, + protocol.ConnectionID, + protocol.ConnectionID, + protocol.ConnectionID, + protocol.PacketNumber, + *Config, + *tls.Config, + *handshake.TransportParameters, + utils.Logger, + protocol.VersionNumber, + ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(p) sess.EXPECT().run().Do(func() { close(run) }) diff --git a/session.go b/session.go index d236c527..c8ca516a 100644 --- a/session.go +++ b/session.go @@ -10,7 +10,6 @@ import ( "sync" "time" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/ackhandler" "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/flowcontrol" @@ -46,7 +45,7 @@ type streamManager interface { } type cryptoStreamHandler interface { - HandleCryptoStream() error + RunHandshake() error ConnectionState() handshake.ConnectionState } @@ -85,14 +84,14 @@ type session struct { conn connection - streamsMap streamManager - cryptoStream cryptoStream + streamsMap streamManager rttStats *congestion.RTTStats + cryptoStreamManager *cryptoStreamManager sentPacketHandler ackhandler.SentPacketHandler receivedPacketHandler ackhandler.ReceivedPacketHandler - framer *framer + framer framer windowUpdateQueue *windowUpdateQueue connFlowController flowcontrol.ConnectionFlowController @@ -115,7 +114,9 @@ type session struct { undecryptablePackets []*receivedPacket receivedTooManyUndecrytablePacketsTime time.Time - // this channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them + clientHelloWritten <-chan struct{} + // This channel is passed to the CryptoSetup and receives the transport parameters, as soon as the peer sends them. + // Only used for gQUIC. paramsChan <-chan handshake.TransportParameters // the handshakeEvent channel is passed to the CryptoSetup. // It receives when it makes sense to try decrypting undecryptable packets. @@ -188,8 +189,14 @@ func newSession( if _, err := rand.Read(divNonce); err != nil { return nil, err } + s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) + cstr, err := s.streamsMap.GetOrOpenReceiveStream(1) + if err != nil { + return nil, err + } + cryptoStream := cstr.(streamI) cs, err := newCryptoSetup( - s.cryptoStream, + cryptoStream, srcConnID, s.conn.RemoteAddr(), s.version, @@ -208,15 +215,14 @@ func newSession( } s.cryptoStreamHandler = cs s.unpacker = newPacketUnpackerGQUIC(cs, s.version) - s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) - s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) + s.framer = newFramer(s.streamsMap, s.version) s.packer = newPacketPackerLegacy( destConnID, srcConnID, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), divNonce, - s.cryptoStream, + cryptoStream, cs, s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, @@ -264,8 +270,14 @@ var newClientSession = func( IdleTimeout: s.config.IdleTimeout, OmitConnectionID: s.config.RequestConnectionIDOmission, } + s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) + cstr, err := s.streamsMap.GetOrOpenReceiveStream(1) + if err != nil { + return nil, err + } + cryptoStream := cstr.(streamI) cs, err := newCryptoSetupClient( - s.cryptoStream, + cryptoStream, destConnID, s.version, tlsConf, @@ -282,17 +294,17 @@ var newClientSession = func( } s.cryptoStreamHandler = cs s.unpacker = newPacketUnpackerGQUIC(cs, s.version) - s.streamsMap = newStreamsMapLegacy(s.newStream, s.config.MaxIncomingStreams, s.perspective) - s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) + framer := newFramer(s.streamsMap, s.version) + s.framer = framer s.packer = newPacketPackerLegacy( destConnID, srcConnID, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), nil, // no diversification nonce - s.cryptoStream, + cryptoStream, cs, - s.framer, + framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, @@ -307,60 +319,71 @@ func newTLSServerSession( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, initialPacketNumber protocol.PacketNumber, - config *Config, - mintConf *mint.Config, - peerParams *handshake.TransportParameters, + conf *Config, + tlsConf *tls.Config, + params *handshake.TransportParameters, logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 1) + handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, sessionRunner: runner, - config: config, + config: conf, srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveServer, - version: v, handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, logger: logger, + version: v, } s.preSetup() + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) + s.framer = newFramer(s.streamsMap, s.version) cs, err := handshake.NewCryptoSetupTLSServer( - s.cryptoStream, + initialStream, + handshakeStream, origConnID, - mintConf, + params, + s.processTransportParameters, handshakeEvent, handshakeCompleteChan, + tlsConf, + conf.Versions, v, + logger, + protocol.PerspectiveServer, ) if err != nil { return nil, err } s.cryptoStreamHandler = cs s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) - s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) + s.framer = newFramer(s.streamsMap, s.version) s.packer = newPacketPacker( s.destConnID, s.srcConnID, + initialStream, + handshakeStream, initialPacketNumber, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), nil, // no token - s.cryptoStream, cs, s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream) + if err := s.postSetup(); err != nil { return nil, err } - s.peerParams = peerParams - s.processTransportParameters(peerParams) s.unpacker = newPacketUnpacker(cs, s.version) return s, nil } @@ -373,13 +396,14 @@ var newTLSClientSession = func( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, - mintConf *mint.Config, - paramsChan <-chan handshake.TransportParameters, + tlsConf *tls.Config, + params *handshake.TransportParameters, + initialVersion protocol.VersionNumber, initialPacketNumber protocol.PacketNumber, logger utils.Logger, v protocol.VersionNumber, ) (quicSession, error) { - handshakeEvent := make(chan struct{}, 1) + handshakeEvent := make(chan struct{}, 2) // TODO: explain cap handshakeCompleteChan := make(chan struct{}) s := &session{ conn: conn, @@ -388,36 +412,47 @@ var newTLSClientSession = func( srcConnID: srcConnID, destConnID: destConnID, perspective: protocol.PerspectiveClient, - version: v, handshakeEvent: handshakeEvent, handshakeCompleteChan: handshakeCompleteChan, - paramsChan: paramsChan, logger: logger, + version: v, } s.preSetup() - cs, err := handshake.NewCryptoSetupTLSClient( - s.cryptoStream, + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + cs, clientHelloWritten, err := handshake.NewCryptoSetupTLSClient( + initialStream, + handshakeStream, s.destConnID, - mintConf, + params, + s.processTransportParameters, handshakeEvent, handshakeCompleteChan, + tlsConf, + initialVersion, + conf.Versions, v, + logger, + protocol.PerspectiveClient, ) if err != nil { return nil, err } + s.clientHelloWritten = clientHelloWritten s.cryptoStreamHandler = cs + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream) s.unpacker = newPacketUnpacker(cs, s.version) s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version) - s.framer = newFramer(s.cryptoStream, s.streamsMap, s.version) + s.framer = newFramer(s.streamsMap, s.version) s.packer = newPacketPacker( s.destConnID, s.srcConnID, + initialStream, + handshakeStream, initialPacketNumber, s.sentPacketHandler.GetPacketNumberLen, s.RemoteAddr(), token, - s.cryptoStream, cs, s.framer, sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, @@ -438,7 +473,6 @@ func (s *session) preSetup() { s.rttStats, s.logger, ) - s.cryptoStream = s.newCryptoStream() } func (s *session) postSetup() error { @@ -462,16 +496,24 @@ func (s *session) run() error { defer s.ctxCancel() go func() { - if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil { + if err := s.cryptoStreamHandler.RunHandshake(); err != nil { s.closeLocal(err) } }() + if s.version.UsesTLS() && s.perspective == protocol.PerspectiveClient { + select { + case <-s.clientHelloWritten: + s.scheduleSending() + case closeErr := <-s.closeChan: + // put the close error back into the channel, so that the run loop can receive it + s.closeChan <- closeErr + } + } var closeErr closeError runLoop: for { - // Close immediately if requested select { case closeErr = <-s.closeChan: @@ -710,6 +752,8 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve var err error wire.LogFrame(s.logger, ff, false) switch frame := ff.(type) { + case *wire.CryptoFrame: + err = s.handleCryptoFrame(frame, encLevel) case *wire.StreamFrame: err = s.handleStreamFrame(frame, encLevel) case *wire.AckFrame: @@ -759,13 +803,20 @@ func (s *session) handlePacket(p *receivedPacket) { } } +func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { + return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) +} + func (s *session) handleStreamFrame(frame *wire.StreamFrame, encLevel protocol.EncryptionLevel) error { - if frame.StreamID == s.version.CryptoStreamID() { + if s.version.IsCryptoStream(frame.StreamID) { if frame.FinBit { return errors.New("Received STREAM frame with FIN bit for the crypto stream") } - return s.cryptoStream.handleStreamFrame(frame) - } else if encLevel <= protocol.EncryptionUnencrypted { + str, _ := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + return str.handleStreamFrame(frame) + } + + if encLevel <= protocol.EncryptionUnencrypted { return qerr.Error(qerr.UnencryptedStreamData, fmt.Sprintf("received unencrypted stream data on stream %d", frame.StreamID)) } str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) @@ -785,10 +836,6 @@ func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { } func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { - if frame.StreamID == s.version.CryptoStreamID() { - s.cryptoStream.handleMaxStreamDataFrame(frame) - return nil - } str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) if err != nil { return err @@ -806,7 +853,7 @@ func (s *session) handleMaxStreamIDFrame(frame *wire.MaxStreamIDFrame) error { } func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { - if frame.StreamID == s.version.CryptoStreamID() { + if s.version.IsCryptoStream(frame.StreamID) { return errors.New("Received RST_STREAM frame for the crypto stream") } str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) @@ -821,7 +868,7 @@ func (s *session) handleRstStreamFrame(frame *wire.RstStreamFrame) error { } func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { - if frame.StreamID == s.version.CryptoStreamID() { + if s.version.IsCryptoStream(frame.StreamID) { return errors.New("Received a STOP_SENDING frame for the crypto stream") } str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) @@ -899,7 +946,6 @@ func (s *session) handleCloseError(closeErr closeError) error { s.logger.Errorf("Closing session with error: %s", closeErr.err.Error()) } - s.cryptoStream.closeForShutdown(quicErr) s.streamsMap.CloseWithError(quicErr) if !closeErr.sendClose { @@ -1183,22 +1229,6 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow ) } -func (s *session) newCryptoStream() cryptoStream { - id := s.version.CryptoStreamID() - flowController := flowcontrol.NewStreamFlowController( - id, - s.version.StreamContributesToConnectionFlowControl(id), - s.connFlowController, - protocol.ReceiveStreamFlowControlWindow, - protocol.ByteCount(s.config.MaxReceiveStreamFlowControlWindow), - 0, - s.onHasStreamWindowUpdate, - s.rttStats, - s.logger, - ) - return newCryptoStream(s, flowController, s.version) -} - func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { s.logger.Infof("Sending PUBLIC_RESET for connection %s, packet number %d", s.destConnID, rejectedPacketNumber) return s.conn.Write(wire.WritePublicReset(s.destConnID, rejectedPacketNumber, 0)) @@ -1253,7 +1283,7 @@ func (s *session) onHasConnectionWindowUpdate() { } func (s *session) onHasStreamData(id protocol.StreamID) { - if id != s.version.CryptoStreamID() { + if !s.version.IsCryptoStream(id) { s.framer.AddActiveStream(id) } s.scheduleSending() diff --git a/session_test.go b/session_test.go index b78bf59f..7d5fe772 100644 --- a/session_test.go +++ b/session_test.go @@ -66,7 +66,7 @@ type mockCryptoSetup struct { var _ handshake.CryptoSetup = &mockCryptoSetup{} -func (m *mockCryptoSetup) HandleCryptoStream() error { return m.handleErr } +func (m *mockCryptoSetup) RunHandshake() error { return m.handleErr } func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { panic("not implemented") } @@ -264,7 +264,7 @@ var _ = Describe("Session", func() { It("errors on a STREAM frame that would close the crypto stream", func() { err := sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: sess.version.CryptoStreamID(), + StreamID: 1, Offset: 0x1337, FinBit: true, }, protocol.EncryptionForwardSecure) @@ -273,22 +273,17 @@ var _ = Describe("Session", func() { It("accepts unencrypted STREAM frames on the crypto stream", func() { f := &wire.StreamFrame{ - StreamID: versionGQUICFrames.CryptoStreamID(), + StreamID: 1, Data: []byte("foobar"), } + str := NewMockStreamI(mockCtrl) + str.EXPECT().handleStreamFrame(f) + streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(1)).Return(str, nil) // for closed streams, the streamManager returns nil err := sess.handleStreamFrame(f, protocol.EncryptionUnencrypted) Expect(err).ToNot(HaveOccurred()) }) - It("unpacks encrypted STREAM frames on the crypto stream", func() { - err := sess.handleStreamFrame(&wire.StreamFrame{ - StreamID: versionGQUICFrames.CryptoStreamID(), - Data: []byte("foobar"), - }, protocol.EncryptionSecure) - Expect(err).ToNot(HaveOccurred()) - }) - - It("does not unpack unencrypted STREAM frames on higher streams", func() { + It("does not handle unencrypted STREAM frames on higher streams", func() { err := sess.handleStreamFrame(&wire.StreamFrame{ StreamID: 3, Data: []byte("foobar"), @@ -361,7 +356,7 @@ var _ = Describe("Session", func() { It("erros when a RST_STREAM frame would reset the crypto stream", func() { err := sess.handleRstStreamFrame(&wire.RstStreamFrame{ - StreamID: sess.version.CryptoStreamID(), + StreamID: 1, ErrorCode: 123, }) Expect(err).To(MatchError("Received RST_STREAM frame for the crypto stream")) @@ -376,18 +371,6 @@ var _ = Describe("Session", func() { sess.connFlowController = connFC }) - It("updates the flow control window of the crypto stream", func() { - fc := mocks.NewMockStreamFlowController(mockCtrl) - offset := protocol.ByteCount(0x4321) - fc.EXPECT().UpdateSendWindow(offset) - sess.cryptoStream.(*cryptoStreamImpl).sendStream.flowController = fc - err := sess.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{ - StreamID: sess.version.CryptoStreamID(), - ByteOffset: offset, - }) - Expect(err).ToNot(HaveOccurred()) - }) - It("updates the flow control window of a stream", func() { f := &wire.MaxStreamDataFrame{ StreamID: 12345, @@ -448,7 +431,7 @@ var _ = Describe("Session", func() { It("errors when receiving a STOP_SENDING for the crypto stream", func() { err := sess.handleStopSendingFrame(&wire.StopSendingFrame{ - StreamID: sess.version.CryptoStreamID(), + StreamID: 1, ErrorCode: 10, }) Expect(err).To(MatchError("Received a STOP_SENDING frame for the crypto stream")) diff --git a/streams_map.go b/streams_map.go index b9a56d6b..584a7934 100644 --- a/streams_map.go +++ b/streams_map.go @@ -48,11 +48,11 @@ func newStreamsMap( var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID if perspective == protocol.PerspectiveServer { firstOutgoingBidiStream = 1 - firstIncomingBidiStream = 4 // the crypto stream is handled separately + firstIncomingBidiStream = 0 firstOutgoingUniStream = 3 firstIncomingUniStream = 2 } else { - firstOutgoingBidiStream = 4 // the crypto stream is handled separately + firstOutgoingBidiStream = 0 firstIncomingBidiStream = 1 firstOutgoingUniStream = 2 firstIncomingUniStream = 3 diff --git a/streams_map_legacy.go b/streams_map_legacy.go index 240eeea1..fdc713ca 100644 --- a/streams_map_legacy.go +++ b/streams_map_legacy.go @@ -17,7 +17,8 @@ type streamsMapLegacy struct { perspective protocol.Perspective - streams map[protocol.StreamID]streamI + cryptoStream streamI + streams map[protocol.StreamID]streamI nextStreamToOpen protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream() highestStreamOpenedByPeer protocol.StreamID @@ -48,6 +49,7 @@ func newStreamsMapLegacy(newStream func(protocol.StreamID) streamI, maxStreams i sm := streamsMapLegacy{ perspective: pers, streams: make(map[protocol.StreamID]streamI), + cryptoStream: newStream(1), newStream: newStream, maxIncomingStreams: maxIncomingStreams, } @@ -90,6 +92,9 @@ func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStream // getOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. func (m *streamsMapLegacy) getOrOpenStream(id protocol.StreamID) (streamI, error) { + if id == 1 { + return m.cryptoStream, nil + } m.mutex.RLock() s, ok := m.streams[id] m.mutex.RUnlock() diff --git a/streams_map_legacy_test.go b/streams_map_legacy_test.go index b9450bc0..33476a93 100644 --- a/streams_map_legacy_test.go +++ b/streams_map_legacy_test.go @@ -46,6 +46,13 @@ var _ = Describe("Streams Map (for gQUIC)", func() { setNewStreamsMap(protocol.PerspectiveServer) }) + It("gets the crypto stream", func() { + s, err := m.getOrOpenStream(1) + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + Expect(s.StreamID()).To(Equal(protocol.StreamID(1))) + }) + Context("client-side streams", func() { It("gets new streams", func() { s, err := m.getOrOpenStream(3) diff --git a/streams_map_test.go b/streams_map_test.go index 2cd58a3e..688019b6 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -30,14 +30,14 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { } serverStreamMapping := streamMapping{ - firstIncomingBidiStream: 4, + firstIncomingBidiStream: 0, firstOutgoingBidiStream: 1, firstIncomingUniStream: 2, firstOutgoingUniStream: 3, } clientStreamMapping := streamMapping{ firstIncomingBidiStream: 1, - firstOutgoingBidiStream: 4, + firstOutgoingBidiStream: 0, firstIncomingUniStream: 3, firstOutgoingUniStream: 2, } @@ -287,7 +287,7 @@ var _ = Describe("Streams Map (for IETF QUIC)", func() { MaxBidiStreams: 5, MaxUniStreams: 5, }) - Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(20))) + Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(16))) Expect(m.outgoingUniStreams.maxStream).To(Equal(protocol.StreamID(18))) }) }) diff --git a/vendor/github.com/bifurcation/mint/LICENSE.md b/vendor/github.com/bifurcation/mint/LICENSE.md deleted file mode 100644 index 63858124..00000000 --- a/vendor/github.com/bifurcation/mint/LICENSE.md +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2016 Richard Barnes - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/vendor/github.com/bifurcation/mint/README.md b/vendor/github.com/bifurcation/mint/README.md deleted file mode 100644 index 9fa05ddd..00000000 --- a/vendor/github.com/bifurcation/mint/README.md +++ /dev/null @@ -1,94 +0,0 @@ -![A lock with a mint leaf](https://ipv.sx/mint/mint.svg) - -mint - A Minimal TLS 1.3 stack -============================== - -[![Build Status](https://circleci.com/gh/bifurcation/mint.svg)](https://circleci.com/gh/bifurcation/mint) - -This project is primarily a learning effort for me to understand the [TLS -1.3](http://tlswg.github.io/tls13-spec/) protocol. The goal is to arrive at a -pretty complete implementation of TLS 1.3, with minimal, elegant code that -demonstrates how things work. Testing is a priority to ensure correctness, but -otherwise, the quality of the software engineering might not be at a level where -it makes sense to integrate this with other libraries. Backward compatibility -is not an objective. - -We borrow liberally from the [Go TLS -library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns -with earlier TLS versions. However, unnecessary parts will be ruthlessly cut -off. - -## DTLS Support - -Mint has partial support for DTLS, but that support is not yet complete -and may still contain serious defects. - - -## Quickstart - -Installation is the same as for any other Go package: - -``` -go get github.com/bifurcation/mint -``` - -The API is pretty much the same as for the TLS module, with `Dial` and `Listen` -methods wrapping the underlying socket APIs. - -``` -conn, err := mint.Dial("tcp", "localhost:4430", &mint.Config{...}) -... -listener, err := mint.Listen("tcp", "localhost:4430", &mint.Config{...}) -``` - -Documentation is available on -[godoc.org](https://godoc.org/github.com/bifurcation/mint) - - -## Interoperability testing - -The `mint-client` and `mint-server` executables are included to make it easy to -do basic interoperability tests with other TLS 1.3 implementations. The steps -for testing against NSS are as follows. - -``` -# Install mint -go get github.com/bifurcation/mint - -# Environment for NSS (you'll probably want a new directory) -NSS_ROOT= -mkdir $NSS_ROOT -cd $NSS_ROOT -export USE_64=1 -export ENABLE_TLS_1_3=1 -export HOST=localhost -export DOMSUF=localhost - -# Build NSS -hg clone https://hg.mozilla.org/projects/nss -hg clone https://hg.mozilla.org/projects/nspr -cd nss -make nss_build_all - -export PLATFORM=`cat $NSS_ROOT/dist/latest` -export DYLD_LIBRARY_PATH=$NSS_ROOT/dist/$PLATFORM/lib -export LD_LIBRARY_PATH=$NSS_ROOT/dist/$PLATFORM/lib - -# Run NSS tests (this creates data for the server to use) -cd tests/ssl_gtests -./ssl_gtests.sh - -# Test with client=mint server=NSS -cd $NSS_ROOT -./dist/$PLATFORM/bin/selfserv -d tests_results/security/$HOST.1/ssl_gtests/ -n rsa -p 4430 -# if you get `NSS_Init failed.`, check the path above, particularly around $HOST -# ... -go run $GOPATH/src/github.com/bifurcation/mint/bin/mint-client/main.go - -# Test with client=NSS server=mint -go run $GOPATH/src/github.com/bifurcation/mint/bin/mint-server/main.go -# ... -cd $NSS_ROOT -dist/$PLATFORM/bin/tstclnt -d tests_results/security/$HOST/ssl_gtests/ -V tls1.3:tls1.3 -h 127.0.0.1 -p 4430 -o -``` - diff --git a/vendor/github.com/bifurcation/mint/alert.go b/vendor/github.com/bifurcation/mint/alert.go deleted file mode 100644 index 430e4554..00000000 --- a/vendor/github.com/bifurcation/mint/alert.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package mint - -import "strconv" - -type Alert uint8 - -const ( - // alert level - AlertLevelWarning = 1 - AlertLevelError = 2 -) - -const ( - AlertCloseNotify Alert = 0 - AlertUnexpectedMessage Alert = 10 - AlertBadRecordMAC Alert = 20 - AlertDecryptionFailed Alert = 21 - AlertRecordOverflow Alert = 22 - AlertDecompressionFailure Alert = 30 - AlertHandshakeFailure Alert = 40 - AlertBadCertificate Alert = 42 - AlertUnsupportedCertificate Alert = 43 - AlertCertificateRevoked Alert = 44 - AlertCertificateExpired Alert = 45 - AlertCertificateUnknown Alert = 46 - AlertIllegalParameter Alert = 47 - AlertUnknownCA Alert = 48 - AlertAccessDenied Alert = 49 - AlertDecodeError Alert = 50 - AlertDecryptError Alert = 51 - AlertProtocolVersion Alert = 70 - AlertInsufficientSecurity Alert = 71 - AlertInternalError Alert = 80 - AlertInappropriateFallback Alert = 86 - AlertUserCanceled Alert = 90 - AlertNoRenegotiation Alert = 100 - AlertMissingExtension Alert = 109 - AlertUnsupportedExtension Alert = 110 - AlertCertificateUnobtainable Alert = 111 - AlertUnrecognizedName Alert = 112 - AlertBadCertificateStatsResponse Alert = 113 - AlertBadCertificateHashValue Alert = 114 - AlertUnknownPSKIdentity Alert = 115 - AlertNoApplicationProtocol Alert = 120 - AlertStatelessRetry Alert = 253 - AlertWouldBlock Alert = 254 - AlertNoAlert Alert = 255 -) - -var alertText = map[Alert]string{ - AlertCloseNotify: "close notify", - AlertUnexpectedMessage: "unexpected message", - AlertBadRecordMAC: "bad record MAC", - AlertDecryptionFailed: "decryption failed", - AlertRecordOverflow: "record overflow", - AlertDecompressionFailure: "decompression failure", - AlertHandshakeFailure: "handshake failure", - AlertBadCertificate: "bad certificate", - AlertUnsupportedCertificate: "unsupported certificate", - AlertCertificateRevoked: "revoked certificate", - AlertCertificateExpired: "expired certificate", - AlertCertificateUnknown: "unknown certificate", - AlertIllegalParameter: "illegal parameter", - AlertUnknownCA: "unknown certificate authority", - AlertAccessDenied: "access denied", - AlertDecodeError: "error decoding message", - AlertDecryptError: "error decrypting message", - AlertProtocolVersion: "protocol version not supported", - AlertInsufficientSecurity: "insufficient security level", - AlertInternalError: "internal error", - AlertInappropriateFallback: "inappropriate fallback", - AlertUserCanceled: "user canceled", - AlertMissingExtension: "missing extension", - AlertUnsupportedExtension: "unsupported extension", - AlertCertificateUnobtainable: "certificate unobtainable", - AlertUnrecognizedName: "unrecognized name", - AlertBadCertificateStatsResponse: "bad certificate status response", - AlertBadCertificateHashValue: "bad certificate hash value", - AlertUnknownPSKIdentity: "unknown PSK identity", - AlertNoApplicationProtocol: "no application protocol", - AlertNoRenegotiation: "no renegotiation", - AlertStatelessRetry: "stateless retry", - AlertWouldBlock: "would have blocked", - AlertNoAlert: "no alert", -} - -func (e Alert) String() string { - s, ok := alertText[e] - if ok { - return s - } - return "alert(" + strconv.Itoa(int(e)) + ")" -} - -func (e Alert) Error() string { - return e.String() -} diff --git a/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/bifurcation/mint/client-state-machine.go deleted file mode 100644 index 07e7f53f..00000000 --- a/vendor/github.com/bifurcation/mint/client-state-machine.go +++ /dev/null @@ -1,1083 +0,0 @@ -package mint - -import ( - "bytes" - "crypto" - "crypto/x509" - "hash" - "time" -) - -// Client State Machine -// -// START <----+ -// Send ClientHello | | Recv HelloRetryRequest -// / v | -// | WAIT_SH ---+ -// Can | | Recv ServerHello -// send | V -// early | WAIT_EE -// data | | Recv EncryptedExtensions -// | +--------+--------+ -// | Using | | Using certificate -// | PSK | v -// | | WAIT_CERT_CR -// | | Recv | | Recv CertificateRequest -// | | Certificate | v -// | | | WAIT_CERT -// | | | | Recv Certificate -// | | v v -// | | WAIT_CV -// | | | Recv CertificateVerify -// | +> WAIT_FINISHED <+ -// | | Recv Finished -// \ | -// | [Send EndOfEarlyData] -// | [Send Certificate [+ CertificateVerify]] -// | Send Finished -// Can send v -// app data --> CONNECTED -// after -// here -// -// State Instructions -// START Send(CH); [RekeyOut; SendEarlyData] -// WAIT_SH Send(CH) || RekeyIn -// WAIT_EE {} -// WAIT_CERT_CR {} -// WAIT_CERT {} -// WAIT_CV {} -// WAIT_FINISHED RekeyIn; [Send(EOED);] RekeyOut; [SendCert; SendCV;] SendFin; RekeyOut; -// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) - -type clientStateStart struct { - Config *Config - Opts ConnectionOptions - Params ConnectionParameters - - cookie []byte - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage - hsCtx *HandshakeContext -} - -var _ HandshakeState = &clientStateStart{} - -func (state clientStateStart) State() State { - return StateClientStart -} - -func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - // key_shares - offeredDH := map[NamedGroup][]byte{} - ks := KeyShareExtension{ - HandshakeType: HandshakeTypeClientHello, - Shares: make([]KeyShareEntry, len(state.Config.Groups)), - } - for i, group := range state.Config.Groups { - pub, priv, err := newKeyShare(group) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error generating key share [%v]", err) - return nil, nil, AlertInternalError - } - - ks.Shares[i].Group = group - ks.Shares[i].KeyExchange = pub - offeredDH[group] = priv - } - - logf(logTypeHandshake, "opts: %+v", state.Opts) - - // supported_versions, supported_groups, signature_algorithms, server_name - sv := SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello, Versions: []uint16{supportedVersion}} - sni := ServerNameExtension(state.Opts.ServerName) - sg := SupportedGroupsExtension{Groups: state.Config.Groups} - sa := SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes} - - state.Params.ServerName = state.Opts.ServerName - - // Application Layer Protocol Negotiation - var alpn *ALPNExtension - if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { - alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} - } - - // Construct base ClientHello - ch := &ClientHelloBody{ - LegacyVersion: wireVersion(state.hsCtx.hIn), - CipherSuites: state.Config.CipherSuites, - } - _, err := prng.Read(ch.Random[:]) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error creating ClientHello random [%v]", err) - return nil, nil, AlertInternalError - } - for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} { - err := ch.Extensions.Add(ext) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err) - return nil, nil, AlertInternalError - } - } - // XXX: These optional extensions can't be folded into the above because Go - // interface-typed values are never reported as nil - if alpn != nil { - err := ch.Extensions.Add(alpn) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) - return nil, nil, AlertInternalError - } - } - if state.cookie != nil { - err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Run the external extension handler. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Send(HandshakeTypeClientHello, &ch.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error running external extension sender [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Handle PSK and EarlyData just before transmitting, so that we can - // calculate the PSK binder value - var psk *PreSharedKeyExtension - var ed *EarlyDataExtension - var offeredPSK PreSharedKey - var earlyHash crypto.Hash - var earlySecret []byte - var clientEarlyTrafficKeys keySet - var clientHello *HandshakeMessage - if key, ok := state.Config.PSKs.Get(state.Opts.ServerName); ok { - offeredPSK = key - - // Narrow ciphersuites to ones that match PSK hash - params, ok := cipherSuiteMap[key.CipherSuite] - if !ok { - logf(logTypeHandshake, "[ClientStateStart] PSK for unknown ciphersuite") - return nil, nil, AlertInternalError - } - - compatibleSuites := []CipherSuite{} - for _, suite := range ch.CipherSuites { - if cipherSuiteMap[suite].Hash == params.Hash { - compatibleSuites = append(compatibleSuites, suite) - } - } - ch.CipherSuites = compatibleSuites - - // TODO(ekr@rtfm.com): Check that the ticket can be used for early - // data. - // Signal early data if we're going to do it - if state.Config.AllowEarlyData && state.helloRetryRequest == nil { - state.Params.ClientSendingEarlyData = true - ed = &EarlyDataExtension{} - err = ch.Extensions.Add(ed) - if err != nil { - logf(logTypeHandshake, "Error adding early data extension: %v", err) - return nil, nil, AlertInternalError - } - } - - // Signal supported PSK key exchange modes - if len(state.Config.PSKModes) == 0 { - logf(logTypeHandshake, "PSK selected, but no PSKModes") - return nil, nil, AlertInternalError - } - kem := &PSKKeyExchangeModesExtension{KEModes: state.Config.PSKModes} - err = ch.Extensions.Add(kem) - if err != nil { - logf(logTypeHandshake, "Error adding PSKKeyExchangeModes extension: %v", err) - return nil, nil, AlertInternalError - } - - // Add the shim PSK extension to the ClientHello - logf(logTypeHandshake, "Adding PSK extension with id = %x", key.Identity) - psk = &PreSharedKeyExtension{ - HandshakeType: HandshakeTypeClientHello, - Identities: []PSKIdentity{ - { - Identity: key.Identity, - ObfuscatedTicketAge: uint32(time.Since(key.ReceivedAt)/time.Millisecond) + key.TicketAgeAdd, - }, - }, - Binders: []PSKBinderEntry{ - // Note: Stub to get the length fields right - {Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())}, - }, - } - ch.Extensions.Add(psk) - - // Compute the binder key - h0 := params.Hash.New().Sum(nil) - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - earlyHash = params.Hash - earlySecret = HkdfExtract(params.Hash, zero, key.Key) - logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) - - binderLabel := labelExternalBinder - if key.IsResumption { - binderLabel = labelResumptionBinder - } - binderKey := deriveSecret(params, earlySecret, binderLabel, h0) - logf(logTypeCrypto, "binder key: [%d] %x", len(binderKey), binderKey) - - // Compute the binder value - trunc, err := ch.Truncated() - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error marshaling truncated ClientHello [%v]", err) - return nil, nil, AlertInternalError - } - - truncHash := params.Hash.New() - truncHash.Write(trunc) - - binder := computeFinishedData(params, binderKey, truncHash.Sum(nil)) - - // Replace the PSK extension - psk.Binders[0].Binder = binder - ch.Extensions.Add(psk) - - // If we got here, the earlier marshal succeeded (in ch.Truncated()), so - // this one should too. - clientHello, _ = state.hsCtx.hOut.HandshakeMessageFromBody(ch) - - // Compute early traffic keys - h := params.Hash.New() - h.Write(clientHello.Marshal()) - chHash := h.Sum(nil) - - earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) - logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) - clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) - } else { - clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch) - if err != nil { - logf(logTypeHandshake, "[ClientStateStart] Error marshaling ClientHello [%v]", err) - return nil, nil, AlertInternalError - } - } - - logf(logTypeHandshake, "[ClientStateStart] -> [ClientStateWaitSH]") - state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. - nextState := clientStateWaitSH{ - Config: state.Config, - Opts: state.Opts, - Params: state.Params, - hsCtx: state.hsCtx, - OfferedDH: offeredDH, - OfferedPSK: offeredPSK, - - earlySecret: earlySecret, - earlyHash: earlyHash, - - firstClientHello: state.firstClientHello, - helloRetryRequest: state.helloRetryRequest, - clientHello: clientHello, - } - - toSend := []HandshakeAction{ - QueueHandshakeMessage{clientHello}, - SendQueuedHandshake{}, - } - if state.Params.ClientSendingEarlyData { - toSend = append(toSend, []HandshakeAction{ - RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - }...) - } - - return nextState, toSend, AlertNoAlert -} - -type clientStateWaitSH struct { - Config *Config - Opts ConnectionOptions - Params ConnectionParameters - hsCtx *HandshakeContext - OfferedDH map[NamedGroup][]byte - OfferedPSK PreSharedKey - PSK []byte - - earlySecret []byte - earlyHash crypto.Hash - - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage - clientHello *HandshakeMessage -} - -var _ HandshakeState = &clientStateWaitSH{} - -func (state clientStateWaitSH) State() State { - return StateClientWaitSH -} - -func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - - if hm == nil || hm.msgType != HandshakeTypeServerHello { - logf(logTypeHandshake, "[ClientStateWaitSH] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - sh := &ServerHelloBody{} - if _, err := sh.Unmarshal(hm.body); err != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - // Common SH/HRR processing first. - // 1. Check that sh.version is TLS 1.2 - if sh.Version != tls12Version { - logf(logTypeHandshake, "[ClientStateWaitSH] illegal legacy version [%v]", sh.Version) - return nil, nil, AlertIllegalParameter - } - - // 2. Check that it responded with a valid version. - supportedVersions := SupportedVersionsExtension{HandshakeType: HandshakeTypeServerHello} - foundSupportedVersions, err := sh.Extensions.Find(&supportedVersions) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] invalid supported_versions extension [%v]", err) - return nil, nil, AlertDecodeError - } - if !foundSupportedVersions { - logf(logTypeHandshake, "[ClientStateWaitSH] no supported_versions extension") - return nil, nil, AlertMissingExtension - } - if supportedVersions.Versions[0] != supportedVersion { - logf(logTypeHandshake, "[ClientStateWaitSH] unsupported version [%x]", supportedVersions.Versions[0]) - return nil, nil, AlertProtocolVersion - } - // 3. Check that the server provided a supported ciphersuite - supportedCipherSuite := false - for _, suite := range state.Config.CipherSuites { - supportedCipherSuite = supportedCipherSuite || (suite == sh.CipherSuite) - } - if !supportedCipherSuite { - logf(logTypeHandshake, "[ClientStateWaitSH] Unsupported ciphersuite [%04x]", sh.CipherSuite) - return nil, nil, AlertHandshakeFailure - } - - // Now check for the sentinel. - - if sh.Random == hrrRandomSentinel { - // This is actually HRR. - hrr := sh - - // Narrow the supported ciphersuites to the server-provided one - state.Config.CipherSuites = []CipherSuite{hrr.CipherSuite} - - // Handle external extensions. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Receive(HandshakeTypeHelloRetryRequest, &hrr.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - // The only thing we know how to respond to in an HRR is the Cookie - // extension, so if there is either no Cookie extension or anything other - // than a Cookie extension and SupportedVersions we have to fail. - serverCookie := new(CookieExtension) - foundCookie, err := hrr.Extensions.Find(serverCookie) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitSH] Invalid server cookie extension [%v]", err) - return nil, nil, AlertDecodeError - } - if !foundCookie || len(hrr.Extensions) != 2 { - logf(logTypeHandshake, "[ClientStateWaitSH] No Cookie or extra extensions [%v] [%d]", foundCookie, len(hrr.Extensions)) - return nil, nil, AlertIllegalParameter - } - - // Hash the body into a pseudo-message - // XXX: Ignoring some errors here - params := cipherSuiteMap[hrr.CipherSuite] - h := params.Hash.New() - h.Write(state.clientHello.Marshal()) - firstClientHello := &HandshakeMessage{ - msgType: HandshakeTypeMessageHash, - body: h.Sum(nil), - } - - state.hsCtx.receivedEndOfFlight() - - // TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT - // mode. In DTLS, we also need to bump the sequence number. - // This is a pre-existing defect in Mint. Issue #175. - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") - return clientStateStart{ - Config: state.Config, - Opts: state.Opts, - hsCtx: state.hsCtx, - cookie: serverCookie.Cookie, - firstClientHello: firstClientHello, - helloRetryRequest: hm, - }, []HandshakeAction{ResetOut{1}}, AlertNoAlert - } - - // This is SH. - // Handle external extensions. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Receive(HandshakeTypeServerHello, &sh.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitSH] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Do PSK or key agreement depending on extensions - serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} - serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} - - foundExts, err := sh.Extensions.Parse( - []ExtensionBody{ - &serverPSK, - &serverKeyShare, - }) - if err != nil { - logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err) - return nil, nil, AlertDecodeError - } - - if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) { - state.Params.UsingPSK = true - } - - var dhSecret []byte - if foundExts[ExtensionTypeKeyShare] { - sks := serverKeyShare.Shares[0] - priv, ok := state.OfferedDH[sks.Group] - if !ok { - logf(logTypeHandshake, "[ClientStateWaitSH] Key share for unknown group") - return nil, nil, AlertIllegalParameter - } - - state.Params.UsingDH = true - dhSecret, _ = keyAgreement(sks.Group, sks.KeyExchange, priv) - } - - suite := sh.CipherSuite - state.Params.CipherSuite = suite - - params, ok := cipherSuiteMap[suite] - if !ok { - logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", suite) - return nil, nil, AlertHandshakeFailure - } - - // Start up the handshake hash - handshakeHash := params.Hash.New() - handshakeHash.Write(state.firstClientHello.Marshal()) - handshakeHash.Write(state.helloRetryRequest.Marshal()) - handshakeHash.Write(state.clientHello.Marshal()) - handshakeHash.Write(hm.Marshal()) - - // Compute handshake secrets - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - var earlySecret []byte - if state.Params.UsingPSK { - if params.Hash != state.earlyHash { - logf(logTypeCrypto, "Change of hash between early and normal init early=[%02x] suite=[%04x] hash=[%02x]", - state.earlyHash, suite, params.Hash) - } - - earlySecret = state.earlySecret - } else { - earlySecret = HkdfExtract(params.Hash, zero, zero) - } - - if dhSecret == nil { - dhSecret = zero - } - - h0 := params.Hash.New().Sum(nil) - h2 := handshakeHash.Sum(nil) - preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) - handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, dhSecret) - clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) - serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) - preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) - masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) - - logf(logTypeCrypto, "early secret: [%d] %x", len(earlySecret), earlySecret) - logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) - logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) - logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) - logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) - - serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") - nextState := clientStateWaitEE{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: params, - handshakeHash: handshakeHash, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: serverHandshakeTrafficSecret, - } - toSend := []HandshakeAction{ - RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, - } - // We're definitely not going to have to send anything with - // early data. - if !state.Params.ClientSendingEarlyData { - toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, - KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)}) - } - - return nextState, toSend, AlertNoAlert -} - -type clientStateWaitEE struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -var _ HandshakeState = &clientStateWaitEE{} - -func (state clientStateWaitEE) State() State { - return StateClientWaitEE -} - -func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeEncryptedExtensions { - logf(logTypeHandshake, "[ClientStateWaitEE] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - ee := EncryptedExtensionsBody{} - if err := safeUnmarshal(&ee, hm.body); err != nil { - logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - // Handle external extensions. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Receive(HandshakeTypeEncryptedExtensions, &ee.Extensions) - if err != nil { - logf(logTypeHandshake, "[ClientWaitStateEE] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - serverALPN := &ALPNExtension{} - serverEarlyData := &EarlyDataExtension{} - - foundExts, err := ee.Extensions.Parse( - []ExtensionBody{ - serverALPN, - serverEarlyData, - }) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err) - return nil, nil, AlertDecodeError - } - - state.Params.UsingEarlyData = foundExts[ExtensionTypeEarlyData] - - if foundExts[ExtensionTypeALPN] && len(serverALPN.Protocols) > 0 { - state.Params.NextProto = serverALPN.Protocols[0] - } - - state.handshakeHash.Write(hm.Marshal()) - - toSend := []HandshakeAction{} - - if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData { - // We didn't get 0-RTT, so rekey to handshake. - toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, - KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) - } - - if state.Params.UsingPSK { - logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") - nextState := clientStateWaitFinished{ - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.Config.Certificates, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, toSend, AlertNoAlert - } - - logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") - nextState := clientStateWaitCertCR{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, toSend, AlertNoAlert -} - -type clientStateWaitCertCR struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -var _ HandshakeState = &clientStateWaitCertCR{} - -func (state clientStateWaitCertCR) State() State { - return StateClientWaitCertCR -} - -func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil { - logf(logTypeHandshake, "[ClientStateWaitCertCR] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - bodyGeneric, err := hm.ToBody() - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCertCR] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - switch body := bodyGeneric.(type) { - case *CertificateBody: - logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCV]") - nextState := clientStateWaitCV{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - serverCertificate: body, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert - - case *CertificateRequestBody: - // A certificate request in the handshake should have a zero-length context - if len(body.CertificateRequestContext) > 0 { - logf(logTypeHandshake, "[ClientStateWaitCertCR] Certificate request with non-empty context: %v", err) - return nil, nil, AlertIllegalParameter - } - - state.Params.UsingClientAuth = true - - logf(logTypeHandshake, "[ClientStateWaitCertCR] -> [ClientStateWaitCert]") - nextState := clientStateWaitCert{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - serverCertificateRequest: body, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert - } - - return nil, nil, AlertUnexpectedMessage -} - -type clientStateWaitCert struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - - serverCertificateRequest *CertificateRequestBody - - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -var _ HandshakeState = &clientStateWaitCert{} - -func (state clientStateWaitCert) State() State { - return StateClientWaitCert -} - -func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeCertificate { - logf(logTypeHandshake, "[ClientStateWaitCert] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - cert := &CertificateBody{} - if err := safeUnmarshal(cert, hm.body); err != nil { - logf(logTypeHandshake, "[ClientStateWaitCert] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - logf(logTypeHandshake, "[ClientStateWaitCert] -> [ClientStateWaitCV]") - nextState := clientStateWaitCV{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - serverCertificate: cert, - serverCertificateRequest: state.serverCertificateRequest, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - } - return nextState, nil, AlertNoAlert -} - -type clientStateWaitCV struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - - serverCertificate *CertificateBody - serverCertificateRequest *CertificateRequestBody - - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -var _ HandshakeState = &clientStateWaitCV{} - -func (state clientStateWaitCV) State() State { - return StateClientWaitCV -} - -func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { - logf(logTypeHandshake, "[ClientStateWaitCV] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - certVerify := CertificateVerifyBody{} - if err := safeUnmarshal(&certVerify, hm.body); err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - hcv := state.handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey - if err := certVerify.Verify(serverPublicKey, hcv); err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") - return nil, nil, AlertHandshakeFailure - } - - certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList)) - rawCerts := make([][]byte, len(state.serverCertificate.CertificateList)) - for i, certEntry := range state.serverCertificate.CertificateList { - certs[i] = certEntry.CertData - rawCerts[i] = certEntry.CertData.Raw - } - - var verifiedChains [][]*x509.Certificate - if !state.Config.InsecureSkipVerify { - opts := x509.VerifyOptions{ - Roots: state.Config.RootCAs, - CurrentTime: state.Config.time(), - DNSName: state.Config.ServerName, - Intermediates: x509.NewCertPool(), - } - - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - var err error - verifiedChains, err = certs[0].Verify(opts) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err) - return nil, nil, AlertBadCertificate - } - } - - if state.Config.VerifyPeerCertificate != nil { - if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err) - return nil, nil, AlertBadCertificate - } - } - - state.handshakeHash.Write(hm.Marshal()) - - logf(logTypeHandshake, "[ClientStateWaitCV] -> [ClientStateWaitFinished]") - nextState := clientStateWaitFinished{ - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - certificates: state.Config.Certificates, - serverCertificateRequest: state.serverCertificateRequest, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, - peerCertificates: certs, - verifiedChains: verifiedChains, - } - return nextState, nil, AlertNoAlert -} - -type clientStateWaitFinished struct { - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - handshakeHash hash.Hash - - certificates []*Certificate - serverCertificateRequest *CertificateRequestBody - peerCertificates []*x509.Certificate - verifiedChains [][]*x509.Certificate - - masterSecret []byte - clientHandshakeTrafficSecret []byte - serverHandshakeTrafficSecret []byte -} - -var _ HandshakeState = &clientStateWaitFinished{} - -func (state clientStateWaitFinished) State() State { - return StateClientWaitFinished -} - -func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeFinished { - logf(logTypeHandshake, "[ClientStateWaitFinished] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - // Verify server's Finished - h3 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) - logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) - - serverFinishedData := computeFinishedData(state.cryptoParams, state.serverHandshakeTrafficSecret, h3) - logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) - - fin := &FinishedBody{VerifyDataLen: len(serverFinishedData)} - if err := safeUnmarshal(fin, hm.body); err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - if !bytes.Equal(fin.VerifyData, serverFinishedData) { - logf(logTypeHandshake, "[ClientStateWaitFinished] Server's Finished failed to verify [%x] != [%x]", - fin.VerifyData, serverFinishedData) - return nil, nil, AlertHandshakeFailure - } - - // Update the handshake hash with the Finished - state.handshakeHash.Write(hm.Marshal()) - logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(hm.Marshal()), hm.Marshal()) - h4 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 4 [%d]: %x", len(h4), h4) - - // Compute traffic secrets and keys - clientTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelClientApplicationTrafficSecret, h4) - serverTrafficSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelServerApplicationTrafficSecret, h4) - logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) - logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) - - clientTrafficKeys := makeTrafficKeys(state.cryptoParams, clientTrafficSecret) - serverTrafficKeys := makeTrafficKeys(state.cryptoParams, serverTrafficSecret) - - exporterSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelExporterSecret, h4) - logf(logTypeCrypto, "client exporter secret: [%d] %x", len(exporterSecret), exporterSecret) - - // Assemble client's second flight - toSend := []HandshakeAction{} - - if state.Params.UsingEarlyData { - logf(logTypeHandshake, "Sending end of early data") - // Note: We only send EOED if the server is actually going to use the early - // data. Otherwise, it will never see it, and the transcripts will - // mismatch. - // EOED marshal is infallible - eoedm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(&EndOfEarlyDataBody{}) - toSend = append(toSend, QueueHandshakeMessage{eoedm}) - - state.handshakeHash.Write(eoedm.Marshal()) - logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) - - // And then rekey to handshake - toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, - KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) - } - - if state.Params.UsingClientAuth { - // Extract constraints from certicateRequest - schemes := SignatureAlgorithmsExtension{} - gotSchemes, err := state.serverCertificateRequest.Extensions.Find(&schemes) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING invalid signature_schemes extension [%v]", err) - return nil, nil, AlertDecodeError - } - if !gotSchemes { - logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found") - return nil, nil, AlertIllegalParameter - } - - // Select a certificate - cert, certScheme, err := CertificateSelection(nil, schemes.Algorithms, state.certificates) - if err != nil { - // XXX: Signal this to the application layer? - logf(logTypeHandshake, "[ClientStateWaitFinished] WARNING no appropriate certificate found [%v]", err) - - certificate := &CertificateBody{} - certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, QueueHandshakeMessage{certm}) - state.handshakeHash.Write(certm.Marshal()) - } else { - // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(cert.Chain)), - } - for i, entry := range cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} - } - certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, QueueHandshakeMessage{certm}) - state.handshakeHash.Write(certm.Marshal()) - - hcv := state.handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - certificateVerify := &CertificateVerifyBody{Algorithm: certScheme} - logf(logTypeHandshake, "Creating CertVerify: %04x %v", certScheme, state.cryptoParams.Hash) - - err = certificateVerify.Sign(cert.PrivateKey, hcv) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error signing CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, QueueHandshakeMessage{certvm}) - state.handshakeHash.Write(certvm.Marshal()) - } - } - - // Compute the client's Finished message - h5 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) - - clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) - logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) - - fin = &FinishedBody{ - VerifyDataLen: len(clientFinishedData), - VerifyData: clientFinishedData, - } - finm, err := state.hsCtx.hOut.HandshakeMessageFromBody(fin) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling client Finished [%v]", err) - return nil, nil, AlertInternalError - } - - // Compute the resumption secret - state.handshakeHash.Write(finm.Marshal()) - h6 := state.handshakeHash.Sum(nil) - - resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) - logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) - - toSend = append(toSend, []HandshakeAction{ - QueueHandshakeMessage{finm}, - SendQueuedHandshake{}, - RekeyIn{epoch: EpochApplicationData, KeySet: serverTrafficKeys}, - RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, - }...) - - state.hsCtx.receivedEndOfFlight() - - logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") - nextState := stateConnected{ - Params: state.Params, - hsCtx: state.hsCtx, - isClient: true, - cryptoParams: state.cryptoParams, - resumptionSecret: resumptionSecret, - clientTrafficSecret: clientTrafficSecret, - serverTrafficSecret: serverTrafficSecret, - exporterSecret: exporterSecret, - peerCertificates: state.peerCertificates, - verifiedChains: state.verifiedChains, - } - return nextState, toSend, AlertNoAlert -} diff --git a/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/bifurcation/mint/common.go deleted file mode 100644 index 05af3e95..00000000 --- a/vendor/github.com/bifurcation/mint/common.go +++ /dev/null @@ -1,266 +0,0 @@ -package mint - -import ( - "fmt" - "strconv" -) - -const ( - supportedVersion uint16 = 0x7f16 // draft-22 - tls12Version uint16 = 0x0303 - tls10Version uint16 = 0x0301 - dtls12WireVersion uint16 = 0xfefd -) - -var ( - // Flags for some minor compat issues - allowWrongVersionNumber = true - allowPKCS1 = true -) - -// enum {...} ContentType; -type RecordType byte - -const ( - RecordTypeAlert RecordType = 21 - RecordTypeHandshake RecordType = 22 - RecordTypeApplicationData RecordType = 23 - RecordTypeAck RecordType = 25 -) - -// enum {...} HandshakeType; -type HandshakeType byte - -const ( - // Omitted: *_RESERVED - HandshakeTypeClientHello HandshakeType = 1 - HandshakeTypeServerHello HandshakeType = 2 - HandshakeTypeNewSessionTicket HandshakeType = 4 - HandshakeTypeEndOfEarlyData HandshakeType = 5 - HandshakeTypeHelloRetryRequest HandshakeType = 6 - HandshakeTypeEncryptedExtensions HandshakeType = 8 - HandshakeTypeCertificate HandshakeType = 11 - HandshakeTypeCertificateRequest HandshakeType = 13 - HandshakeTypeCertificateVerify HandshakeType = 15 - HandshakeTypeServerConfiguration HandshakeType = 17 - HandshakeTypeFinished HandshakeType = 20 - HandshakeTypeKeyUpdate HandshakeType = 24 - HandshakeTypeMessageHash HandshakeType = 254 -) - -var hrrRandomSentinel = [32]byte{ - 0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, - 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91, - 0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, - 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c, -} - -// uint8 CipherSuite[2]; -type CipherSuite uint16 - -const ( - // XXX: Actually TLS_NULL_WITH_NULL_NULL, but we need a way to label the zero - // value for this type so that we can detect when a field is set. - CIPHER_SUITE_UNKNOWN CipherSuite = 0x0000 - TLS_AES_128_GCM_SHA256 CipherSuite = 0x1301 - TLS_AES_256_GCM_SHA384 CipherSuite = 0x1302 - TLS_CHACHA20_POLY1305_SHA256 CipherSuite = 0x1303 - TLS_AES_128_CCM_SHA256 CipherSuite = 0x1304 - TLS_AES_256_CCM_8_SHA256 CipherSuite = 0x1305 -) - -func (c CipherSuite) String() string { - switch c { - case CIPHER_SUITE_UNKNOWN: - return "unknown" - case TLS_AES_128_GCM_SHA256: - return "TLS_AES_128_GCM_SHA256" - case TLS_AES_256_GCM_SHA384: - return "TLS_AES_256_GCM_SHA384" - case TLS_CHACHA20_POLY1305_SHA256: - return "TLS_CHACHA20_POLY1305_SHA256" - case TLS_AES_128_CCM_SHA256: - return "TLS_AES_128_CCM_SHA256" - case TLS_AES_256_CCM_8_SHA256: - return "TLS_AES_256_CCM_8_SHA256" - } - // cannot use %x here, since it calls String(), leading to infinite recursion - return fmt.Sprintf("invalid CipherSuite value: 0x%s", strconv.FormatUint(uint64(c), 16)) -} - -// enum {...} SignatureScheme -type SignatureScheme uint16 - -const ( - // RSASSA-PKCS1-v1_5 algorithms - RSA_PKCS1_SHA1 SignatureScheme = 0x0201 - RSA_PKCS1_SHA256 SignatureScheme = 0x0401 - RSA_PKCS1_SHA384 SignatureScheme = 0x0501 - RSA_PKCS1_SHA512 SignatureScheme = 0x0601 - // ECDSA algorithms - ECDSA_P256_SHA256 SignatureScheme = 0x0403 - ECDSA_P384_SHA384 SignatureScheme = 0x0503 - ECDSA_P521_SHA512 SignatureScheme = 0x0603 - // RSASSA-PSS algorithms - RSA_PSS_SHA256 SignatureScheme = 0x0804 - RSA_PSS_SHA384 SignatureScheme = 0x0805 - RSA_PSS_SHA512 SignatureScheme = 0x0806 - // EdDSA algorithms - Ed25519 SignatureScheme = 0x0807 - Ed448 SignatureScheme = 0x0808 -) - -// enum {...} ExtensionType -type ExtensionType uint16 - -const ( - ExtensionTypeServerName ExtensionType = 0 - ExtensionTypeSupportedGroups ExtensionType = 10 - ExtensionTypeSignatureAlgorithms ExtensionType = 13 - ExtensionTypeALPN ExtensionType = 16 - ExtensionTypeKeyShare ExtensionType = 40 - ExtensionTypePreSharedKey ExtensionType = 41 - ExtensionTypeEarlyData ExtensionType = 42 - ExtensionTypeSupportedVersions ExtensionType = 43 - ExtensionTypeCookie ExtensionType = 44 - ExtensionTypePSKKeyExchangeModes ExtensionType = 45 - ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 -) - -// enum {...} NamedGroup -type NamedGroup uint16 - -const ( - // Elliptic Curve Groups. - P256 NamedGroup = 23 - P384 NamedGroup = 24 - P521 NamedGroup = 25 - // ECDH functions. - X25519 NamedGroup = 29 - X448 NamedGroup = 30 - // Finite field groups. - FFDHE2048 NamedGroup = 256 - FFDHE3072 NamedGroup = 257 - FFDHE4096 NamedGroup = 258 - FFDHE6144 NamedGroup = 259 - FFDHE8192 NamedGroup = 260 -) - -// enum {...} PskKeyExchangeMode; -type PSKKeyExchangeMode uint8 - -const ( - PSKModeKE PSKKeyExchangeMode = 0 - PSKModeDHEKE PSKKeyExchangeMode = 1 -) - -// enum { -// update_not_requested(0), update_requested(1), (255) -// } KeyUpdateRequest; -type KeyUpdateRequest uint8 - -const ( - KeyUpdateNotRequested KeyUpdateRequest = 0 - KeyUpdateRequested KeyUpdateRequest = 1 -) - -type State uint8 - -const ( - StateInit = 0 - - // states valid for the client - StateClientStart State = iota - StateClientWaitSH - StateClientWaitEE - StateClientWaitCert - StateClientWaitCV - StateClientWaitFinished - StateClientWaitCertCR - StateClientConnected - // states valid for the server - StateServerStart State = iota - StateServerRecvdCH - StateServerNegotiated - StateServerReadPastEarlyData - StateServerWaitEOED - StateServerWaitFlight2 - StateServerWaitCert - StateServerWaitCV - StateServerWaitFinished - StateServerConnected -) - -func (s State) String() string { - switch s { - case StateClientStart: - return "Client START" - case StateClientWaitSH: - return "Client WAIT_SH" - case StateClientWaitEE: - return "Client WAIT_EE" - case StateClientWaitCert: - return "Client WAIT_CERT" - case StateClientWaitCV: - return "Client WAIT_CV" - case StateClientWaitFinished: - return "Client WAIT_FINISHED" - case StateClientWaitCertCR: - return "Client WAIT_CERT_CR" - case StateClientConnected: - return "Client CONNECTED" - case StateServerStart: - return "Server START" - case StateServerRecvdCH: - return "Server RECVD_CH" - case StateServerNegotiated: - return "Server NEGOTIATED" - case StateServerReadPastEarlyData: - return "Server READ_PAST_EARLY_DATA" - case StateServerWaitEOED: - return "Server WAIT_EOED" - case StateServerWaitFlight2: - return "Server WAIT_FLIGHT2" - case StateServerWaitCert: - return "Server WAIT_CERT" - case StateServerWaitCV: - return "Server WAIT_CV" - case StateServerWaitFinished: - return "Server WAIT_FINISHED" - case StateServerConnected: - return "Server CONNECTED" - default: - return fmt.Sprintf("unknown state: %d", s) - } -} - -// Epochs for DTLS (also used for key phase labelling) -type Epoch uint16 - -const ( - EpochClear Epoch = 0 - EpochEarlyData Epoch = 1 - EpochHandshakeData Epoch = 2 - EpochApplicationData Epoch = 3 - EpochUpdate Epoch = 4 -) - -func (e Epoch) label() string { - switch e { - case EpochClear: - return "clear" - case EpochEarlyData: - return "early data" - case EpochHandshakeData: - return "handshake" - case EpochApplicationData: - return "application data" - } - return "Application data (updated)" -} - -func assert(b bool) { - if !b { - panic("Assertion failed") - } -} diff --git a/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/bifurcation/mint/conn.go deleted file mode 100644 index 12a99171..00000000 --- a/vendor/github.com/bifurcation/mint/conn.go +++ /dev/null @@ -1,921 +0,0 @@ -package mint - -import ( - "crypto" - "crypto/x509" - "encoding/hex" - "errors" - "fmt" - "io" - "net" - "reflect" - "sync" - "time" -) - -type Certificate struct { - Chain []*x509.Certificate - PrivateKey crypto.Signer -} - -type PreSharedKey struct { - CipherSuite CipherSuite - IsResumption bool - Identity []byte - Key []byte - NextProto string - ReceivedAt time.Time - ExpiresAt time.Time - TicketAgeAdd uint32 -} - -type PreSharedKeyCache interface { - Get(string) (PreSharedKey, bool) - Put(string, PreSharedKey) - Size() int -} - -// A CookieHandler can be used to give the application more fine-grained control over Cookies. -// Generate receives the Conn as an argument, so the CookieHandler can decide when to send the cookie based on that, and offload state to the client by encoding that into the Cookie. -// When the client echoes the Cookie, Validate is called. The application can then recover the state from the cookie. -type CookieHandler interface { - // Generate a byte string that is sent as a part of a cookie to the client in the HelloRetryRequest - // If Generate returns nil, mint will not send a HelloRetryRequest. - Generate(*Conn) ([]byte, error) - // Validate is called when receiving a ClientHello containing a Cookie. - // If validation failed, the handshake is aborted. - Validate(*Conn, []byte) bool -} - -type PSKMapCache map[string]PreSharedKey - -func (cache PSKMapCache) Get(key string) (psk PreSharedKey, ok bool) { - psk, ok = cache[key] - return -} - -func (cache *PSKMapCache) Put(key string, psk PreSharedKey) { - (*cache)[key] = psk -} - -func (cache PSKMapCache) Size() int { - return len(cache) -} - -// Config is the struct used to pass configuration settings to a TLS client or -// server instance. The settings for client and server are pretty different, -// but we just throw them all in here. -type Config struct { - // Client fields - ServerName string - - // Server fields - SendSessionTickets bool - TicketLifetime uint32 - TicketLen int - EarlyDataLifetime uint32 - AllowEarlyData bool - // Require the client to echo a cookie. - RequireCookie bool - // A CookieHandler can be used to set and validate a cookie. - // The cookie returned by the CookieHandler will be part of the cookie sent on the wire, and encoded using the CookieProtector. - // If no CookieHandler is set, mint will always send a cookie. - // The CookieHandler can be used to decide on a per-connection basis, if a cookie should be sent. - CookieHandler CookieHandler - // The CookieProtector is used to encrypt / decrypt cookies. - // It should make sure that the Cookie cannot be read and tampered with by the client. - // If non-blocking mode is used, and cookies are required, this field has to be set. - // In blocking mode, a default cookie protector is used, if this is unused. - CookieProtector CookieProtector - // The ExtensionHandler is used to add custom extensions. - ExtensionHandler AppExtensionHandler - RequireClientAuth bool - - // Time returns the current time as the number of seconds since the epoch. - // If Time is nil, TLS uses time.Now. - Time func() time.Time - // RootCAs defines the set of root certificate authorities - // that clients use when verifying server certificates. - // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *x509.CertPool - // InsecureSkipVerify controls whether a client verifies the - // server's certificate chain and host name. - // If InsecureSkipVerify is true, TLS accepts any certificate - // presented by the server and any host name in that certificate. - // In this mode, TLS is susceptible to man-in-the-middle attacks. - // This should be used only for testing. - InsecureSkipVerify bool - - // Shared fields - Certificates []*Certificate - // VerifyPeerCertificate, if not nil, is called after normal - // certificate verification by either a TLS client or server. It - // receives the raw ASN.1 certificates provided by the peer and also - // any verified chains that normal processing found. If it returns a - // non-nil error, the handshake is aborted and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. If normal verification is disabled by - // setting InsecureSkipVerify then this callback will be considered but - // the verifiedChains argument will always be nil. - VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - - CipherSuites []CipherSuite - Groups []NamedGroup - SignatureSchemes []SignatureScheme - NextProtos []string - PSKs PreSharedKeyCache - PSKModes []PSKKeyExchangeMode - NonBlocking bool - UseDTLS bool - - // The same config object can be shared among different connections, so it - // needs its own mutex - mutex sync.RWMutex -} - -// Clone returns a shallow clone of c. It is safe to clone a Config that is -// being used concurrently by a TLS client or server. -func (c *Config) Clone() *Config { - c.mutex.Lock() - defer c.mutex.Unlock() - - return &Config{ - ServerName: c.ServerName, - - SendSessionTickets: c.SendSessionTickets, - TicketLifetime: c.TicketLifetime, - TicketLen: c.TicketLen, - EarlyDataLifetime: c.EarlyDataLifetime, - AllowEarlyData: c.AllowEarlyData, - RequireCookie: c.RequireCookie, - CookieHandler: c.CookieHandler, - CookieProtector: c.CookieProtector, - ExtensionHandler: c.ExtensionHandler, - RequireClientAuth: c.RequireClientAuth, - Time: c.Time, - RootCAs: c.RootCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - - Certificates: c.Certificates, - VerifyPeerCertificate: c.VerifyPeerCertificate, - CipherSuites: c.CipherSuites, - Groups: c.Groups, - SignatureSchemes: c.SignatureSchemes, - NextProtos: c.NextProtos, - PSKs: c.PSKs, - PSKModes: c.PSKModes, - NonBlocking: c.NonBlocking, - UseDTLS: c.UseDTLS, - } -} - -func (c *Config) Init(isClient bool) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - // Set defaults - if len(c.CipherSuites) == 0 { - c.CipherSuites = defaultSupportedCipherSuites - } - if len(c.Groups) == 0 { - c.Groups = defaultSupportedGroups - } - if len(c.SignatureSchemes) == 0 { - c.SignatureSchemes = defaultSignatureSchemes - } - if c.TicketLen == 0 { - c.TicketLen = defaultTicketLen - } - if !reflect.ValueOf(c.PSKs).IsValid() { - c.PSKs = &PSKMapCache{} - } - if len(c.PSKModes) == 0 { - c.PSKModes = defaultPSKModes - } - return nil -} - -func (c *Config) ValidForServer() bool { - return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) || - (len(c.Certificates) > 0 && - len(c.Certificates[0].Chain) > 0 && - c.Certificates[0].PrivateKey != nil) -} - -func (c *Config) ValidForClient() bool { - return len(c.ServerName) > 0 -} - -func (c *Config) time() time.Time { - t := c.Time - if t == nil { - t = time.Now - } - return t() -} - -var ( - defaultSupportedCipherSuites = []CipherSuite{ - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - } - - defaultSupportedGroups = []NamedGroup{ - P256, - P384, - FFDHE2048, - X25519, - } - - defaultSignatureSchemes = []SignatureScheme{ - RSA_PSS_SHA256, - RSA_PSS_SHA384, - RSA_PSS_SHA512, - ECDSA_P256_SHA256, - ECDSA_P384_SHA384, - ECDSA_P521_SHA512, - } - - defaultTicketLen = 16 - - defaultPSKModes = []PSKKeyExchangeMode{ - PSKModeKE, - PSKModeDHEKE, - } -) - -type ConnectionState struct { - HandshakeState State - CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) - PeerCertificates []*x509.Certificate // certificate chain presented by remote peer - VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates - NextProto string // Selected ALPN proto - UsingPSK bool // Are we using PSK. - UsingEarlyData bool // Did we negotiate 0-RTT. -} - -// Conn implements the net.Conn interface, as with "crypto/tls" -// * Read, Write, and Close are provided locally -// * LocalAddr, RemoteAddr, and Set*Deadline are forwarded to the inner Conn -type Conn struct { - config *Config - conn net.Conn - isClient bool - - state stateConnected - hState HandshakeState - handshakeMutex sync.Mutex - handshakeAlert Alert - handshakeComplete bool - - readBuffer []byte - in, out *RecordLayer - hsCtx *HandshakeContext -} - -func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { - c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}} - if !config.UseDTLS { - c.in = NewRecordLayerTLS(c.conn, directionRead) - c.out = NewRecordLayerTLS(c.conn, directionWrite) - c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in) - c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out) - } else { - c.in = NewRecordLayerDTLS(c.conn, directionRead) - c.out = NewRecordLayerDTLS(c.conn, directionWrite) - c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in) - c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out) - c.hsCtx.timeoutMS = initialTimeout - c.hsCtx.timers = newTimerSet() - c.hsCtx.waitingNextFlight = true - } - c.in.label = c.label() - c.out.label = c.label() - c.hsCtx.hIn.nonblocking = c.config.NonBlocking - return c -} - -// Read up -func (c *Conn) consumeRecord() error { - pt, err := c.in.ReadRecord() - if pt == nil { - logf(logTypeIO, "extendBuffer returns error %v", err) - return err - } - - switch pt.contentType { - case RecordTypeHandshake: - logf(logTypeHandshake, "Received post-handshake message") - // We do not support fragmentation of post-handshake handshake messages. - // TODO: Factor this more elegantly; coalesce with handshakeLayer.ReadMessage() - start := 0 - headerLen := handshakeHeaderLenTLS - if c.config.UseDTLS { - headerLen = handshakeHeaderLenDTLS - } - for start < len(pt.fragment) { - if len(pt.fragment[start:]) < headerLen { - return fmt.Errorf("Post-handshake handshake message too short for header") - } - - hm := &HandshakeMessage{} - hm.msgType = HandshakeType(pt.fragment[start]) - hmLen := (int(pt.fragment[start+1]) << 16) + (int(pt.fragment[start+2]) << 8) + int(pt.fragment[start+3]) - - if len(pt.fragment[start+headerLen:]) < hmLen { - return fmt.Errorf("Post-handshake handshake message too short for body") - } - hm.body = pt.fragment[start+headerLen : start+headerLen+hmLen] - - // XXX: If we want to support more advanced cases, e.g., post-handshake - // authentication, we'll need to allow transitions other than - // Connected -> Connected - state, actions, alert := c.state.ProcessMessage(hm) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error in state transition: %v", alert) - c.sendAlert(alert) - return io.EOF - } - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return io.EOF - } - } - - var connected bool - c.state, connected = state.(stateConnected) - if !connected { - logf(logTypeHandshake, "Disconnected after state transition: %v", alert) - c.sendAlert(alert) - return io.EOF - } - - start += headerLen + hmLen - } - case RecordTypeAlert: - logf(logTypeIO, "extended buffer (for alert): [%d] %x", len(c.readBuffer), c.readBuffer) - if len(pt.fragment) != 2 { - c.sendAlert(AlertUnexpectedMessage) - return io.EOF - } - if Alert(pt.fragment[1]) == AlertCloseNotify { - return io.EOF - } - - switch pt.fragment[0] { - case AlertLevelWarning: - // drop on the floor - case AlertLevelError: - return Alert(pt.fragment[1]) - default: - c.sendAlert(AlertUnexpectedMessage) - return io.EOF - } - - case RecordTypeAck: - if !c.hsCtx.hIn.datagram { - logf(logTypeHandshake, "Received ACK in TLS mode") - return AlertUnexpectedMessage - } - return c.hsCtx.processAck(pt.fragment) - - case RecordTypeApplicationData: - c.readBuffer = append(c.readBuffer, pt.fragment...) - logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) - - } - - return err -} - -func readPartial(in *[]byte, buffer []byte) int { - logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in))) - read := copy(buffer, *in) - *in = (*in)[read:] - - logf(logTypeVerbose, "Returning %v", string(buffer)) - return read -} - -// Read application data up to the size of buffer. Handshake and alert records -// are consumed by the Conn object directly. -func (c *Conn) Read(buffer []byte) (int, error) { - if _, connected := c.hState.(stateConnected); !connected { - // Clients can't call Read prior to handshake completion. - if c.isClient { - return 0, errors.New("Read called before the handshake completed") - } - - // Neither can servers that don't allow early data. - if !c.config.AllowEarlyData { - return 0, errors.New("Read called before the handshake completed") - } - - // If there's no early data, then return WouldBlock - if len(c.hsCtx.earlyData) == 0 { - return 0, AlertWouldBlock - } - - return readPartial(&c.hsCtx.earlyData, buffer), nil - } - - // The handshake is now connected. - logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) - if alert := c.Handshake(); alert != AlertNoAlert { - return 0, alert - } - - if len(buffer) == 0 { - return 0, nil - } - - // Run our timers. - if c.config.UseDTLS { - if err := c.hsCtx.timers.check(time.Now()); err != nil { - return 0, AlertInternalError - } - } - - // Lock the input channel - c.in.Lock() - defer c.in.Unlock() - for len(c.readBuffer) == 0 { - err := c.consumeRecord() - - // err can be nil if consumeRecord processed a non app-data - // record. - if err != nil { - if c.config.NonBlocking || err != AlertWouldBlock { - logf(logTypeIO, "conn.Read returns err=%v", err) - return 0, err - } - } - } - - return readPartial(&c.readBuffer, buffer), nil -} - -// Write application data -func (c *Conn) Write(buffer []byte) (int, error) { - // Lock the output channel - c.out.Lock() - defer c.out.Unlock() - - if !c.Writable() { - return 0, errors.New("Write called before the handshake completed (and early data not in use)") - } - - // Send full-size fragments - var start int - sent := 0 - for start = 0; len(buffer)-start >= maxFragmentLen; start += maxFragmentLen { - err := c.out.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeApplicationData, - fragment: buffer[start : start+maxFragmentLen], - }) - - if err != nil { - return sent, err - } - sent += maxFragmentLen - } - - // Send a final partial fragment if necessary - if start < len(buffer) { - err := c.out.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeApplicationData, - fragment: buffer[start:], - }) - - if err != nil { - return sent, err - } - sent += len(buffer[start:]) - } - return sent, nil -} - -// sendAlert sends a TLS alert message. -// c.out.Mutex <= L. -func (c *Conn) sendAlert(err Alert) error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - var level int - switch err { - case AlertNoRenegotiation, AlertCloseNotify: - level = AlertLevelWarning - default: - level = AlertLevelError - } - - buf := []byte{byte(err), byte(level)} - c.out.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeAlert, - fragment: buf, - }) - - // close_notify and end_of_early_data are not actually errors - if level == AlertLevelWarning { - return &net.OpError{Op: "local error", Err: err} - } - - return c.Close() -} - -// Close closes the connection. -func (c *Conn) Close() error { - // XXX crypto/tls has an interlock with Write here. Do we need that? - - return c.conn.Close() -} - -// LocalAddr returns the local network address. -func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -// RemoteAddr returns the remote network address. -func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline sets the read and write deadlines associated with the connection. -// A zero value for t means Read and Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying connection. -// A zero value for t means Read will not time out. -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying connection. -// A zero value for t means Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { - label := "[server]" - if c.isClient { - label = "[client]" - } - - switch action := actionGeneric.(type) { - case QueueHandshakeMessage: - logf(logTypeHandshake, "%s queuing handshake message type=%v", label, action.Message.msgType) - err := c.hsCtx.hOut.QueueMessage(action.Message) - if err != nil { - logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) - return AlertInternalError - } - - case SendQueuedHandshake: - _, err := c.hsCtx.hOut.SendQueuedMessages() - if err != nil { - logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) - return AlertInternalError - } - if c.config.UseDTLS { - c.hsCtx.timers.start(retransmitTimerLabel, - c.hsCtx.handshakeRetransmit, - c.hsCtx.timeoutMS) - } - case RekeyIn: - logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet) - // Check that we don't have an input data in the handshake frame parser. - if len(c.hsCtx.hIn.frame.remainder) > 0 { - logf(logTypeHandshake, "%s Rekey with data still in handshake buffers", label) - return AlertDecodeError - } - err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) - if err != nil { - logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) - return AlertInternalError - } - - case RekeyOut: - logf(logTypeHandshake, "%s Rekeying out to %s: %+v", label, action.epoch.label(), action.KeySet) - err := c.out.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) - if err != nil { - logf(logTypeHandshake, "%s Unable to rekey outbound: %v", label, err) - return AlertInternalError - } - - case ResetOut: - logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq) - c.out.ResetClear(action.seq) - - case StorePSK: - logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) - if c.isClient { - // Clients look up PSKs based on server name - c.config.PSKs.Put(c.config.ServerName, action.PSK) - } else { - // Servers look them up based on the identity in the extension - c.config.PSKs.Put(hex.EncodeToString(action.PSK.Identity), action.PSK) - } - - default: - logf(logTypeHandshake, "%s Unknown action type", label) - assert(false) - return AlertInternalError - } - - return AlertNoAlert -} - -func (c *Conn) HandshakeSetup() Alert { - var state HandshakeState - var actions []HandshakeAction - var alert Alert - - if err := c.config.Init(c.isClient); err != nil { - logf(logTypeHandshake, "Error initializing config: %v", err) - return AlertInternalError - } - - opts := ConnectionOptions{ - ServerName: c.config.ServerName, - NextProtos: c.config.NextProtos, - } - - if c.isClient { - state, actions, alert = clientStateStart{Config: c.config, Opts: opts, hsCtx: c.hsCtx}.Next(nil) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error initializing client state: %v", alert) - return alert - } - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - return alert - } - } - } else { - if c.config.RequireCookie && c.config.CookieProtector == nil { - logf(logTypeHandshake, "RequireCookie set, but no CookieProtector provided. Using default cookie protector. Stateless Retry not possible.") - if c.config.NonBlocking { - logf(logTypeHandshake, "Not possible in non-blocking mode.") - return AlertInternalError - } - var err error - c.config.CookieProtector, err = NewDefaultCookieProtector() - if err != nil { - logf(logTypeHandshake, "Error initializing cookie source: %v", alert) - return AlertInternalError - } - } - state = serverStateStart{Config: c.config, conn: c, hsCtx: c.hsCtx} - } - - c.hState = state - return AlertNoAlert -} - -type handshakeMessageReader interface { - ReadMessage() (*HandshakeMessage, Alert) -} - -type handshakeMessageReaderImpl struct { - hsCtx *HandshakeContext -} - -var _ handshakeMessageReader = &handshakeMessageReaderImpl{} - -func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) { - var hm *HandshakeMessage - var err error - for { - hm, err = r.hsCtx.hIn.ReadMessage() - if err == AlertWouldBlock { - return nil, AlertWouldBlock - } - if err != nil { - logf(logTypeHandshake, "Error reading message: %v", err) - return nil, AlertCloseNotify - } - if hm != nil { - break - } - } - - return hm, AlertNoAlert -} - -// Handshake causes a TLS handshake on the connection. The `isClient` member -// determines whether a client or server handshake is performed. If a -// handshake has already been performed, then its result will be returned. -func (c *Conn) Handshake() Alert { - label := "[server]" - if c.isClient { - label = "[client]" - } - - // TODO Lock handshakeMutex - // TODO Remove CloseNotify hack - if c.handshakeAlert != AlertNoAlert && c.handshakeAlert != AlertCloseNotify { - logf(logTypeHandshake, "Pre-existing handshake error: %v", c.handshakeAlert) - return c.handshakeAlert - } - if c.handshakeComplete { - return AlertNoAlert - } - - if c.hState == nil { - logf(logTypeHandshake, "%s First time through handshake (or after stateless retry), setting up", label) - alert := c.HandshakeSetup() - if alert != AlertNoAlert || (c.isClient && c.config.NonBlocking) { - return alert - } - } - - logf(logTypeHandshake, "(Re-)entering handshake, state=%v", c.hState) - state := c.hState - _, connected := state.(stateConnected) - - hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx} - for !connected { - var alert Alert - var actions []HandshakeAction - - // Advance the state machine - state, actions, alert = state.Next(hmr) - if alert == AlertWouldBlock { - logf(logTypeHandshake, "%s Would block reading message: %s", label, alert) - // If we blocked, then run our timers to see if any have expired. - if c.hsCtx.hIn.datagram { - if err := c.hsCtx.timers.check(time.Now()); err != nil { - return AlertInternalError - } - } - return AlertWouldBlock - } - if alert == AlertCloseNotify { - logf(logTypeHandshake, "%s Error reading message: %s", label, alert) - c.sendAlert(AlertCloseNotify) - return AlertCloseNotify - } - if alert != AlertNoAlert && alert != AlertStatelessRetry { - logf(logTypeHandshake, "Error in state transition: %v", alert) - return alert - } - - for index, action := range actions { - logf(logTypeHandshake, "%s taking next action (%d)", label, index) - if alert := c.takeAction(action); alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert - } - } - - c.hState = state - logf(logTypeHandshake, "state is now %s", c.GetHsState()) - _, connected = state.(stateConnected) - if connected { - c.state = state.(stateConnected) - c.handshakeComplete = true - - if !c.isClient { - // Send NewSessionTicket if configured to - if c.config.SendSessionTickets { - actions, alert := c.state.NewSessionTicket( - c.config.TicketLen, - c.config.TicketLifetime, - c.config.EarlyDataLifetime) - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert - } - } - } - - // If there is early data, move it into the main buffer - if c.hsCtx.earlyData != nil { - c.readBuffer = c.hsCtx.earlyData - c.hsCtx.earlyData = nil - } - - } else { - assert(c.hsCtx.earlyData == nil) - } - } - - if c.config.NonBlocking { - if alert == AlertStatelessRetry { - return AlertStatelessRetry - } - return AlertNoAlert - } - } - - return AlertNoAlert -} - -func (c *Conn) SendKeyUpdate(requestUpdate bool) error { - if !c.handshakeComplete { - return fmt.Errorf("Cannot update keys until after handshake") - } - - request := KeyUpdateNotRequested - if requestUpdate { - request = KeyUpdateRequested - } - - // Create the key update and update state - actions, alert := c.state.KeyUpdate(request) - if alert != AlertNoAlert { - c.sendAlert(alert) - return fmt.Errorf("Alert while generating key update: %v", alert) - } - - // Take actions (send key update and rekey) - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - c.sendAlert(alert) - return fmt.Errorf("Alert during key update actions: %v", alert) - } - } - - return nil -} - -func (c *Conn) GetHsState() State { - if c.hState == nil { - return StateInit - } - return c.hState.State() -} - -func (c *Conn) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) { - _, connected := c.hState.(stateConnected) - if !connected { - return nil, fmt.Errorf("Cannot compute exporter when state is not connected") - } - - if c.state.exporterSecret == nil { - return nil, fmt.Errorf("Internal error: no exporter secret") - } - - h0 := c.state.cryptoParams.Hash.New().Sum(nil) - tmpSecret := deriveSecret(c.state.cryptoParams, c.state.exporterSecret, label, h0) - - hc := c.state.cryptoParams.Hash.New().Sum(context) - return HkdfExpandLabel(c.state.cryptoParams.Hash, tmpSecret, "exporter", hc, keyLength), nil -} - -func (c *Conn) ConnectionState() ConnectionState { - state := ConnectionState{ - HandshakeState: c.GetHsState(), - } - - if c.handshakeComplete { - state.CipherSuite = cipherSuiteMap[c.state.Params.CipherSuite] - state.NextProto = c.state.Params.NextProto - state.VerifiedChains = c.state.verifiedChains - state.PeerCertificates = c.state.peerCertificates - state.UsingPSK = c.state.Params.UsingPSK - state.UsingEarlyData = c.state.Params.UsingEarlyData - } - - return state -} - -func (c *Conn) Writable() bool { - // If we're connected, we're writable. - if _, connected := c.hState.(stateConnected); connected { - return true - } - - // If we're a client in 0-RTT, then we're writable. - if c.isClient && c.out.cipher.epoch == EpochEarlyData { - return true - } - - return false -} - -func (c *Conn) label() string { - if c.isClient { - return "client" - } - return "server" -} diff --git a/vendor/github.com/bifurcation/mint/cookie-protector.go b/vendor/github.com/bifurcation/mint/cookie-protector.go deleted file mode 100644 index 73dd80ba..00000000 --- a/vendor/github.com/bifurcation/mint/cookie-protector.go +++ /dev/null @@ -1,86 +0,0 @@ -package mint - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha256" - "fmt" - "io" - - "golang.org/x/crypto/hkdf" -) - -// CookieProtector is used to create and verify a cookie -type CookieProtector interface { - // NewToken creates a new token - NewToken([]byte) ([]byte, error) - // DecodeToken decodes a token - DecodeToken([]byte) ([]byte, error) -} - -const cookieSecretSize = 32 -const cookieNonceSize = 32 - -// The DefaultCookieProtector is a simple implementation for the CookieProtector. -type DefaultCookieProtector struct { - secret []byte -} - -var _ CookieProtector = &DefaultCookieProtector{} - -// NewDefaultCookieProtector creates a source for source address tokens -func NewDefaultCookieProtector() (CookieProtector, error) { - secret := make([]byte, cookieSecretSize) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - return &DefaultCookieProtector{secret: secret}, nil -} - -// NewToken encodes data into a new token. -func (s *DefaultCookieProtector) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, cookieNonceSize) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } - aead, aeadNonce, err := s.createAEAD(nonce) - if err != nil { - return nil, err - } - return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil -} - -// DecodeToken decodes a token. -func (s *DefaultCookieProtector) DecodeToken(p []byte) ([]byte, error) { - if len(p) < cookieNonceSize { - return nil, fmt.Errorf("Token too short: %d", len(p)) - } - nonce := p[:cookieNonceSize] - aead, aeadNonce, err := s.createAEAD(nonce) - if err != nil { - return nil, err - } - return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil) -} - -func (s *DefaultCookieProtector) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.secret, nonce, []byte("mint cookie source")) - key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 - if _, err := io.ReadFull(h, key); err != nil { - return nil, nil, err - } - aeadNonce := make([]byte, 12) - if _, err := io.ReadFull(h, aeadNonce); err != nil { - return nil, nil, err - } - c, err := aes.NewCipher(key) - if err != nil { - return nil, nil, err - } - aead, err := cipher.NewGCM(c) - if err != nil { - return nil, nil, err - } - return aead, aeadNonce, nil -} diff --git a/vendor/github.com/bifurcation/mint/crypto.go b/vendor/github.com/bifurcation/mint/crypto.go deleted file mode 100644 index ef7397d8..00000000 --- a/vendor/github.com/bifurcation/mint/crypto.go +++ /dev/null @@ -1,667 +0,0 @@ -package mint - -import ( - "bytes" - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/hmac" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "fmt" - "math/big" - "time" - - "golang.org/x/crypto/curve25519" - - // Blank includes to ensure hash support - _ "crypto/sha1" - _ "crypto/sha256" - _ "crypto/sha512" -) - -var prng = rand.Reader - -type aeadFactory func(key []byte) (cipher.AEAD, error) - -type CipherSuiteParams struct { - Suite CipherSuite - Cipher aeadFactory // Cipher factory - Hash crypto.Hash // Hash function - KeyLen int // Key length in octets - IvLen int // IV length in octets -} - -type signatureAlgorithm uint8 - -const ( - signatureAlgorithmUnknown = iota - signatureAlgorithmRSA_PKCS1 - signatureAlgorithmRSA_PSS - signatureAlgorithmECDSA -) - -var ( - hashMap = map[SignatureScheme]crypto.Hash{ - RSA_PKCS1_SHA1: crypto.SHA1, - RSA_PKCS1_SHA256: crypto.SHA256, - RSA_PKCS1_SHA384: crypto.SHA384, - RSA_PKCS1_SHA512: crypto.SHA512, - ECDSA_P256_SHA256: crypto.SHA256, - ECDSA_P384_SHA384: crypto.SHA384, - ECDSA_P521_SHA512: crypto.SHA512, - RSA_PSS_SHA256: crypto.SHA256, - RSA_PSS_SHA384: crypto.SHA384, - RSA_PSS_SHA512: crypto.SHA512, - } - - sigMap = map[SignatureScheme]signatureAlgorithm{ - RSA_PKCS1_SHA1: signatureAlgorithmRSA_PKCS1, - RSA_PKCS1_SHA256: signatureAlgorithmRSA_PKCS1, - RSA_PKCS1_SHA384: signatureAlgorithmRSA_PKCS1, - RSA_PKCS1_SHA512: signatureAlgorithmRSA_PKCS1, - ECDSA_P256_SHA256: signatureAlgorithmECDSA, - ECDSA_P384_SHA384: signatureAlgorithmECDSA, - ECDSA_P521_SHA512: signatureAlgorithmECDSA, - RSA_PSS_SHA256: signatureAlgorithmRSA_PSS, - RSA_PSS_SHA384: signatureAlgorithmRSA_PSS, - RSA_PSS_SHA512: signatureAlgorithmRSA_PSS, - } - - curveMap = map[SignatureScheme]NamedGroup{ - ECDSA_P256_SHA256: P256, - ECDSA_P384_SHA384: P384, - ECDSA_P521_SHA512: P521, - } - - newAESGCM = func(key []byte) (cipher.AEAD, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - // TLS always uses 12-byte nonces - return cipher.NewGCMWithNonceSize(block, 12) - } - - cipherSuiteMap = map[CipherSuite]CipherSuiteParams{ - TLS_AES_128_GCM_SHA256: { - Suite: TLS_AES_128_GCM_SHA256, - Cipher: newAESGCM, - Hash: crypto.SHA256, - KeyLen: 16, - IvLen: 12, - }, - TLS_AES_256_GCM_SHA384: { - Suite: TLS_AES_256_GCM_SHA384, - Cipher: newAESGCM, - Hash: crypto.SHA384, - KeyLen: 32, - IvLen: 12, - }, - } - - x509AlgMap = map[SignatureScheme]x509.SignatureAlgorithm{ - RSA_PKCS1_SHA1: x509.SHA1WithRSA, - RSA_PKCS1_SHA256: x509.SHA256WithRSA, - RSA_PKCS1_SHA384: x509.SHA384WithRSA, - RSA_PKCS1_SHA512: x509.SHA512WithRSA, - ECDSA_P256_SHA256: x509.ECDSAWithSHA256, - ECDSA_P384_SHA384: x509.ECDSAWithSHA384, - ECDSA_P521_SHA512: x509.ECDSAWithSHA512, - } - - defaultRSAKeySize = 2048 -) - -func curveFromNamedGroup(group NamedGroup) (crv elliptic.Curve) { - switch group { - case P256: - crv = elliptic.P256() - case P384: - crv = elliptic.P384() - case P521: - crv = elliptic.P521() - } - return -} - -func namedGroupFromECDSAKey(key *ecdsa.PublicKey) (g NamedGroup) { - switch key.Curve.Params().Name { - case elliptic.P256().Params().Name: - g = P256 - case elliptic.P384().Params().Name: - g = P384 - case elliptic.P521().Params().Name: - g = P521 - } - return -} - -func keyExchangeSizeFromNamedGroup(group NamedGroup) (size int) { - size = 0 - switch group { - case X25519: - size = 32 - case P256: - size = 65 - case P384: - size = 97 - case P521: - size = 133 - case FFDHE2048: - size = 256 - case FFDHE3072: - size = 384 - case FFDHE4096: - size = 512 - case FFDHE6144: - size = 768 - case FFDHE8192: - size = 1024 - } - return -} - -func primeFromNamedGroup(group NamedGroup) (p *big.Int) { - switch group { - case FFDHE2048: - p = finiteFieldPrime2048 - case FFDHE3072: - p = finiteFieldPrime3072 - case FFDHE4096: - p = finiteFieldPrime4096 - case FFDHE6144: - p = finiteFieldPrime6144 - case FFDHE8192: - p = finiteFieldPrime8192 - } - return -} - -func schemeValidForKey(alg SignatureScheme, key crypto.Signer) bool { - sigType := sigMap[alg] - switch key.(type) { - case *rsa.PrivateKey: - return sigType == signatureAlgorithmRSA_PKCS1 || sigType == signatureAlgorithmRSA_PSS - case *ecdsa.PrivateKey: - return sigType == signatureAlgorithmECDSA - default: - return false - } -} - -func ffdheKeyShareFromPrime(p *big.Int) (priv, pub *big.Int, err error) { - primeLen := len(p.Bytes()) - for { - // g = 2 for all ffdhe groups - priv, err = rand.Int(prng, p) - if err != nil { - return - } - - pub = big.NewInt(0) - pub.Exp(big.NewInt(2), priv, p) - - if len(pub.Bytes()) == primeLen { - return - } - } -} - -func newKeyShare(group NamedGroup) (pub []byte, priv []byte, err error) { - switch group { - case P256, P384, P521: - var x, y *big.Int - crv := curveFromNamedGroup(group) - priv, x, y, err = elliptic.GenerateKey(crv, prng) - if err != nil { - return - } - - pub = elliptic.Marshal(crv, x, y) - return - - case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: - p := primeFromNamedGroup(group) - x, X, err2 := ffdheKeyShareFromPrime(p) - if err2 != nil { - err = err2 - return - } - - priv = x.Bytes() - pubBytes := X.Bytes() - - numBytes := keyExchangeSizeFromNamedGroup(group) - - pub = make([]byte, numBytes) - copy(pub[numBytes-len(pubBytes):], pubBytes) - - return - - case X25519: - var private, public [32]byte - _, err = prng.Read(private[:]) - if err != nil { - return - } - - curve25519.ScalarBaseMult(&public, &private) - priv = private[:] - pub = public[:] - return - - default: - return nil, nil, fmt.Errorf("tls.newkeyshare: Unsupported group %v", group) - } -} - -func keyAgreement(group NamedGroup, pub []byte, priv []byte) ([]byte, error) { - switch group { - case P256, P384, P521: - if len(pub) != keyExchangeSizeFromNamedGroup(group) { - return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") - } - - crv := curveFromNamedGroup(group) - pubX, pubY := elliptic.Unmarshal(crv, pub) - x, _ := crv.Params().ScalarMult(pubX, pubY, priv) - xBytes := x.Bytes() - - numBytes := len(crv.Params().P.Bytes()) - - ret := make([]byte, numBytes) - copy(ret[numBytes-len(xBytes):], xBytes) - - return ret, nil - - case FFDHE2048, FFDHE3072, FFDHE4096, FFDHE6144, FFDHE8192: - numBytes := keyExchangeSizeFromNamedGroup(group) - if len(pub) != numBytes { - return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") - } - p := primeFromNamedGroup(group) - x := big.NewInt(0).SetBytes(priv) - Y := big.NewInt(0).SetBytes(pub) - ZBytes := big.NewInt(0).Exp(Y, x, p).Bytes() - - ret := make([]byte, numBytes) - copy(ret[numBytes-len(ZBytes):], ZBytes) - - return ret, nil - - case X25519: - if len(pub) != keyExchangeSizeFromNamedGroup(group) { - return nil, fmt.Errorf("tls.keyagreement: Wrong public key size") - } - - var private, public, ret [32]byte - copy(private[:], priv) - copy(public[:], pub) - curve25519.ScalarMult(&ret, &private, &public) - - return ret[:], nil - - default: - return nil, fmt.Errorf("tls.keyagreement: Unsupported group %v", group) - } -} - -func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { - switch sig { - case RSA_PKCS1_SHA1, RSA_PKCS1_SHA256, - RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, - RSA_PSS_SHA256, RSA_PSS_SHA384, - RSA_PSS_SHA512: - return rsa.GenerateKey(prng, defaultRSAKeySize) - case ECDSA_P256_SHA256: - return ecdsa.GenerateKey(elliptic.P256(), prng) - case ECDSA_P384_SHA384: - return ecdsa.GenerateKey(elliptic.P384(), prng) - case ECDSA_P521_SHA512: - return ecdsa.GenerateKey(elliptic.P521(), prng) - default: - return nil, fmt.Errorf("tls.newsigningkey: Unsupported signature algorithm [%04x]", sig) - } -} - -// XXX(rlb): Copied from crypto/x509 -type ecdsaSignature struct { - R, S *big.Int -} - -func sign(alg SignatureScheme, privateKey crypto.Signer, sigInput []byte) ([]byte, error) { - var opts crypto.SignerOpts - - hash := hashMap[alg] - if hash == crypto.SHA1 { - return nil, fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") - } - - sigType := sigMap[alg] - var realInput []byte - switch key := privateKey.(type) { - case *rsa.PrivateKey: - switch { - case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - logf(logTypeCrypto, "signing with PKCS1, hashSize=[%d]", hash.Size()) - opts = hash - case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - fallthrough - case sigType == signatureAlgorithmRSA_PSS: - logf(logTypeCrypto, "signing with PSS, hashSize=[%d]", hash.Size()) - opts = &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} - default: - return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for RSA key") - } - - h := hash.New() - h.Write(sigInput) - realInput = h.Sum(nil) - case *ecdsa.PrivateKey: - if sigType != signatureAlgorithmECDSA { - return nil, fmt.Errorf("tls.crypto.sign: Unsupported algorithm for ECDSA key") - } - - algGroup := curveMap[alg] - keyGroup := namedGroupFromECDSAKey(key.Public().(*ecdsa.PublicKey)) - if algGroup != keyGroup { - return nil, fmt.Errorf("tls.crypto.sign: Unsupported hash/curve combination") - } - - h := hash.New() - h.Write(sigInput) - realInput = h.Sum(nil) - default: - return nil, fmt.Errorf("tls.crypto.sign: Unsupported private key type") - } - - sig, err := privateKey.Sign(prng, realInput, opts) - logf(logTypeCrypto, "signature: %x", sig) - return sig, err -} - -func verify(alg SignatureScheme, publicKey crypto.PublicKey, sigInput []byte, sig []byte) error { - hash := hashMap[alg] - - if hash == crypto.SHA1 { - return fmt.Errorf("tls.crypt.sign: Use of SHA-1 is forbidden") - } - - sigType := sigMap[alg] - switch pub := publicKey.(type) { - case *rsa.PublicKey: - switch { - case allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - logf(logTypeCrypto, "verifying with PKCS1, hashSize=[%d]", hash.Size()) - - h := hash.New() - h.Write(sigInput) - realInput := h.Sum(nil) - return rsa.VerifyPKCS1v15(pub, hash, realInput, sig) - case !allowPKCS1 && sigType == signatureAlgorithmRSA_PKCS1: - fallthrough - case sigType == signatureAlgorithmRSA_PSS: - logf(logTypeCrypto, "verifying with PSS, hashSize=[%d]", hash.Size()) - opts := &rsa.PSSOptions{SaltLength: hash.Size(), Hash: hash} - - h := hash.New() - h.Write(sigInput) - realInput := h.Sum(nil) - return rsa.VerifyPSS(pub, hash, realInput, sig, opts) - default: - return fmt.Errorf("tls.verify: Unsupported algorithm for RSA key") - } - - case *ecdsa.PublicKey: - if sigType != signatureAlgorithmECDSA { - return fmt.Errorf("tls.verify: Unsupported algorithm for ECDSA key") - } - - if curveMap[alg] != namedGroupFromECDSAKey(pub) { - return fmt.Errorf("tls.verify: Unsupported curve for ECDSA key") - } - - ecdsaSig := new(ecdsaSignature) - if rest, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { - return err - } else if len(rest) != 0 { - return fmt.Errorf("tls.verify: trailing data after ECDSA signature") - } - if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { - return fmt.Errorf("tls.verify: ECDSA signature contained zero or negative values") - } - - h := hash.New() - h.Write(sigInput) - realInput := h.Sum(nil) - if !ecdsa.Verify(pub, realInput, ecdsaSig.R, ecdsaSig.S) { - return fmt.Errorf("tls.verify: ECDSA verification failure") - } - return nil - default: - return fmt.Errorf("tls.verify: Unsupported key type") - } -} - -// 0 -// | -// v -// PSK -> HKDF-Extract = Early Secret -// | -// +-----> Derive-Secret(., -// | "ext binder" | -// | "res binder", -// | "") -// | = binder_key -// | -// +-----> Derive-Secret(., "c e traffic", -// | ClientHello) -// | = client_early_traffic_secret -// | -// +-----> Derive-Secret(., "e exp master", -// | ClientHello) -// | = early_exporter_master_secret -// v -// Derive-Secret(., "derived", "") -// | -// v -// (EC)DHE -> HKDF-Extract = Handshake Secret -// | -// +-----> Derive-Secret(., "c hs traffic", -// | ClientHello...ServerHello) -// | = client_handshake_traffic_secret -// | -// +-----> Derive-Secret(., "s hs traffic", -// | ClientHello...ServerHello) -// | = server_handshake_traffic_secret -// v -// Derive-Secret(., "derived", "") -// | -// v -// 0 -> HKDF-Extract = Master Secret -// | -// +-----> Derive-Secret(., "c ap traffic", -// | ClientHello...server Finished) -// | = client_application_traffic_secret_0 -// | -// +-----> Derive-Secret(., "s ap traffic", -// | ClientHello...server Finished) -// | = server_application_traffic_secret_0 -// | -// +-----> Derive-Secret(., "exp master", -// | ClientHello...server Finished) -// | = exporter_master_secret -// | -// +-----> Derive-Secret(., "res master", -// ClientHello...client Finished) -// = resumption_master_secret - -// From RFC 5869 -// PRK = HMAC-Hash(salt, IKM) -func HkdfExtract(hash crypto.Hash, saltIn, input []byte) []byte { - salt := saltIn - - // if [salt is] not provided, it is set to a string of HashLen zeros - if salt == nil { - salt = bytes.Repeat([]byte{0}, hash.Size()) - } - - h := hmac.New(hash.New, salt) - h.Write(input) - out := h.Sum(nil) - - logf(logTypeCrypto, "HKDF Extract:\n") - logf(logTypeCrypto, "Salt [%d]: %x\n", len(salt), salt) - logf(logTypeCrypto, "Input [%d]: %x\n", len(input), input) - logf(logTypeCrypto, "Output [%d]: %x\n", len(out), out) - - return out -} - -const ( - labelExternalBinder = "ext binder" - labelResumptionBinder = "res binder" - labelEarlyTrafficSecret = "c e traffic" - labelEarlyExporterSecret = "e exp master" - labelClientHandshakeTrafficSecret = "c hs traffic" - labelServerHandshakeTrafficSecret = "s hs traffic" - labelClientApplicationTrafficSecret = "c ap traffic" - labelServerApplicationTrafficSecret = "s ap traffic" - labelExporterSecret = "exp master" - labelResumptionSecret = "res master" - labelDerived = "derived" - labelFinished = "finished" - labelResumption = "resumption" -) - -// struct HkdfLabel { -// uint16 length; -// opaque label<9..255>; -// opaque hash_value<0..255>; -// }; -func hkdfEncodeLabel(labelIn string, hashValue []byte, outLen int) []byte { - label := "tls13 " + labelIn - - labelLen := len(label) - hashLen := len(hashValue) - hkdfLabel := make([]byte, 2+1+labelLen+1+hashLen) - hkdfLabel[0] = byte(outLen >> 8) - hkdfLabel[1] = byte(outLen) - hkdfLabel[2] = byte(labelLen) - copy(hkdfLabel[3:3+labelLen], []byte(label)) - hkdfLabel[3+labelLen] = byte(hashLen) - copy(hkdfLabel[3+labelLen+1:], hashValue) - - return hkdfLabel -} - -func HkdfExpand(hash crypto.Hash, prk, info []byte, outLen int) []byte { - out := []byte{} - T := []byte{} - i := byte(1) - for len(out) < outLen { - block := append(T, info...) - block = append(block, i) - - h := hmac.New(hash.New, prk) - h.Write(block) - - T = h.Sum(nil) - out = append(out, T...) - i++ - } - return out[:outLen] -} - -func HkdfExpandLabel(hash crypto.Hash, secret []byte, label string, hashValue []byte, outLen int) []byte { - info := hkdfEncodeLabel(label, hashValue, outLen) - derived := HkdfExpand(hash, secret, info, outLen) - - logf(logTypeCrypto, "HKDF Expand: label=[tls13 ] + '%s',requested length=%d\n", label, outLen) - logf(logTypeCrypto, "PRK [%d]: %x\n", len(secret), secret) - logf(logTypeCrypto, "Hash [%d]: %x\n", len(hashValue), hashValue) - logf(logTypeCrypto, "Info [%d]: %x\n", len(info), info) - logf(logTypeCrypto, "Derived key [%d]: %x\n", len(derived), derived) - - return derived -} - -func deriveSecret(params CipherSuiteParams, secret []byte, label string, messageHash []byte) []byte { - return HkdfExpandLabel(params.Hash, secret, label, messageHash, params.Hash.Size()) -} - -func computeFinishedData(params CipherSuiteParams, baseKey []byte, input []byte) []byte { - macKey := HkdfExpandLabel(params.Hash, baseKey, labelFinished, []byte{}, params.Hash.Size()) - mac := hmac.New(params.Hash.New, macKey) - mac.Write(input) - return mac.Sum(nil) -} - -type keySet struct { - cipher aeadFactory - key []byte - iv []byte -} - -func makeTrafficKeys(params CipherSuiteParams, secret []byte) keySet { - logf(logTypeCrypto, "making traffic keys: secret=%x", secret) - return keySet{ - cipher: params.Cipher, - key: HkdfExpandLabel(params.Hash, secret, "key", []byte{}, params.KeyLen), - iv: HkdfExpandLabel(params.Hash, secret, "iv", []byte{}, params.IvLen), - } -} - -func MakeNewSelfSignedCert(name string, alg SignatureScheme) (crypto.Signer, *x509.Certificate, error) { - priv, err := newSigningKey(alg) - if err != nil { - return nil, nil, err - } - - cert, err := newSelfSigned(name, alg, priv) - if err != nil { - return nil, nil, err - } - return priv, cert, nil -} - -func newSelfSigned(name string, alg SignatureScheme, priv crypto.Signer) (*x509.Certificate, error) { - sigAlg, ok := x509AlgMap[alg] - if !ok { - return nil, fmt.Errorf("tls.selfsigned: Unknown signature algorithm [%04x]", alg) - } - if len(name) == 0 { - return nil, fmt.Errorf("tls.selfsigned: No name provided") - } - - serial, err := rand.Int(rand.Reader, big.NewInt(0xA0A0A0A0)) - if err != nil { - return nil, err - } - - template := &x509.Certificate{ - SerialNumber: serial, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, 1), - SignatureAlgorithm: sigAlg, - Subject: pkix.Name{CommonName: name}, - DNSNames: []string{name}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - } - der, err := x509.CreateCertificate(prng, template, template, priv.Public(), priv) - if err != nil { - return nil, err - } - - // It is safe to ignore the error here because we're parsing known-good data - cert, _ := x509.ParseCertificate(der) - return cert, nil -} diff --git a/vendor/github.com/bifurcation/mint/dtls.go b/vendor/github.com/bifurcation/mint/dtls.go deleted file mode 100644 index aa914e3e..00000000 --- a/vendor/github.com/bifurcation/mint/dtls.go +++ /dev/null @@ -1,222 +0,0 @@ -package mint - -import ( - "fmt" - "github.com/bifurcation/mint/syntax" - "time" -) - -const ( - initialMtu = 1200 - initialTimeout = 100 -) - -// labels for timers -const ( - retransmitTimerLabel = "handshake retransmit" - ackTimerLabel = "ack timer" -) - -type SentHandshakeFragment struct { - seq uint32 - offset int - fragLength int - record uint64 - acked bool -} - -type DtlsAck struct { - RecordNumbers []uint64 `tls:"head=2"` -} - -func wireVersion(h *HandshakeLayer) uint16 { - if h.datagram { - return dtls12WireVersion - } - return tls12Version -} - -func dtlsConvertVersion(version uint16) uint16 { - if version == tls12Version { - return dtls12WireVersion - } - if version == tls10Version { - return 0xfeff - } - panic(fmt.Sprintf("Internal error, unexpected version=%d", version)) -} - -// TODO(ekr@rtfm.com): Move these to state-machine.go -func (h *HandshakeContext) handshakeRetransmit() error { - if _, err := h.hOut.SendQueuedMessages(); err != nil { - return err - } - - h.timers.start(retransmitTimerLabel, - h.handshakeRetransmit, - h.timeoutMS) - - // TODO(ekr@rtfm.com): Back off timer - return nil -} - -func (h *HandshakeContext) sendAck() error { - toack := h.hIn.recvdRecords - - count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU - if len(toack) > count { - toack = toack[:count] - } - logf(logTypeHandshake, "Sending ACK: [%x]", toack) - - ack := &DtlsAck{toack} - body, err := syntax.Marshal(&ack) - if err != nil { - return err - } - err = h.hOut.conn.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeAck, - fragment: body, - }) - if err != nil { - return err - } - return nil -} - -func (h *HandshakeContext) processAck(data []byte) error { - // Cancel the retransmit timer because we will be resending - // and possibly re-arming later. - h.timers.cancel(retransmitTimerLabel) - - ack := &DtlsAck{} - read, err := syntax.Unmarshal(data, &ack) - if err != nil { - return err - } - if len(data) != read { - return fmt.Errorf("Invalid encoding: Extra data not consumed") - } - logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers) - - for _, r := range ack.RecordNumbers { - for _, m := range h.sentFragments { - if r == m.record { - logf(logTypeHandshake, "Marking %v %v(%v) as acked", - m.seq, m.offset, m.fragLength) - m.acked = true - } - } - } - - count, err := h.hOut.SendQueuedMessages() - if err != nil { - return err - } - - if count == 0 { - logf(logTypeHandshake, "All messages ACKed") - h.hOut.ClearQueuedMessages() - return nil - } - - // Reset the timer - h.timers.start(retransmitTimerLabel, - h.handshakeRetransmit, - h.timeoutMS) - - return nil -} - -func (c *Conn) GetDTLSTimeout() (bool, time.Duration) { - return c.hsCtx.timers.remaining() -} - -func (h *HandshakeContext) receivedHandshakeMessage() { - logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight) - // This just enables tests. - if h.hIn == nil { - return - } - - if !h.hIn.datagram { - return - } - - if h.waitingNextFlight { - logf(logTypeHandshake, "Received the start of the flight") - - // Clear the outgoing DTLS queue and terminate the retransmit timer - h.hOut.ClearQueuedMessages() - h.timers.cancel(retransmitTimerLabel) - - // OK, we're not waiting any more. - h.waitingNextFlight = false - } - - // Now pre-emptively arm the ACK timer if it's not armed already. - // We'll automatically dis-arm it at the end of the handshake. - if h.timers.getTimer(ackTimerLabel) == nil { - h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4) - } -} - -func (h *HandshakeContext) receivedEndOfFlight() { - logf(logTypeHandshake, "%p Received the end of the flight", h) - if !h.hIn.datagram { - return - } - - // Empty incoming queue - h.hIn.queued = nil - - // Note that we are waiting for the next flight. - h.waitingNextFlight = true - - // Clear the ACK queue. - h.hIn.recvdRecords = nil - - // Disarm the ACK timer - h.timers.cancel(ackTimerLabel) -} - -func (h *HandshakeContext) receivedFinalFlight() { - logf(logTypeHandshake, "%p Received final flight", h) - if !h.hIn.datagram { - return - } - - // Disarm the ACK timer - h.timers.cancel(ackTimerLabel) - - // But send an ACK immediately. - h.sendAck() -} - -func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool { - logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen) - for _, f := range h.sentFragments { - if !f.acked { - continue - } - - if f.seq != seq { - continue - } - - if f.offset > offset { - continue - } - - // At this point, we know that the stored fragment starts - // at or before what we want to send, so check where the end - // is. - if f.offset+f.fragLength < offset+fraglen { - continue - } - - return true - } - - return false -} diff --git a/vendor/github.com/bifurcation/mint/extensions.go b/vendor/github.com/bifurcation/mint/extensions.go deleted file mode 100644 index 07cb16c6..00000000 --- a/vendor/github.com/bifurcation/mint/extensions.go +++ /dev/null @@ -1,626 +0,0 @@ -package mint - -import ( - "bytes" - "fmt" - "github.com/bifurcation/mint/syntax" -) - -type ExtensionBody interface { - Type() ExtensionType - Marshal() ([]byte, error) - Unmarshal(data []byte) (int, error) -} - -// struct { -// ExtensionType extension_type; -// opaque extension_data<0..2^16-1>; -// } Extension; -type Extension struct { - ExtensionType ExtensionType - ExtensionData []byte `tls:"head=2"` -} - -func (ext Extension) Marshal() ([]byte, error) { - return syntax.Marshal(ext) -} - -func (ext *Extension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, ext) -} - -type ExtensionList []Extension - -type extensionListInner struct { - List []Extension `tls:"head=2"` -} - -func (el ExtensionList) Marshal() ([]byte, error) { - return syntax.Marshal(extensionListInner{el}) -} - -func (el *ExtensionList) Unmarshal(data []byte) (int, error) { - var list extensionListInner - read, err := syntax.Unmarshal(data, &list) - if err != nil { - return 0, err - } - - *el = list.List - return read, nil -} - -func (el *ExtensionList) Add(src ExtensionBody) error { - data, err := src.Marshal() - if err != nil { - return err - } - - if el == nil { - el = new(ExtensionList) - } - - // If one already exists with this type, replace it - for i := range *el { - if (*el)[i].ExtensionType == src.Type() { - (*el)[i].ExtensionData = data - return nil - } - } - - // Otherwise append - *el = append(*el, Extension{ - ExtensionType: src.Type(), - ExtensionData: data, - }) - return nil -} - -func (el ExtensionList) Parse(dsts []ExtensionBody) (map[ExtensionType]bool, error) { - found := make(map[ExtensionType]bool) - - for _, dst := range dsts { - for _, ext := range el { - if ext.ExtensionType == dst.Type() { - if found[dst.Type()] { - return nil, fmt.Errorf("Duplicate extension of type [%v]", dst.Type()) - } - - err := safeUnmarshal(dst, ext.ExtensionData) - if err != nil { - return nil, err - } - - found[dst.Type()] = true - } - } - } - - return found, nil -} - -func (el ExtensionList) Find(dst ExtensionBody) (bool, error) { - for _, ext := range el { - if ext.ExtensionType == dst.Type() { - err := safeUnmarshal(dst, ext.ExtensionData) - if err != nil { - return true, err - } - return true, nil - } - } - return false, nil -} - -// struct { -// NameType name_type; -// select (name_type) { -// case host_name: HostName; -// } name; -// } ServerName; -// -// enum { -// host_name(0), (255) -// } NameType; -// -// opaque HostName<1..2^16-1>; -// -// struct { -// ServerName server_name_list<1..2^16-1> -// } ServerNameList; -// -// But we only care about the case where there's a single DNS hostname. We -// will never create anything else, and throw if we receive something else -// -// 2 1 2 -// | listLen | NameType | nameLen | name | -type ServerNameExtension string - -type serverNameInner struct { - NameType uint8 - HostName []byte `tls:"head=2,min=1"` -} - -type serverNameListInner struct { - ServerNameList []serverNameInner `tls:"head=2,min=1"` -} - -func (sni ServerNameExtension) Type() ExtensionType { - return ExtensionTypeServerName -} - -func (sni ServerNameExtension) Marshal() ([]byte, error) { - list := serverNameListInner{ - ServerNameList: []serverNameInner{{ - NameType: 0x00, // host_name - HostName: []byte(sni), - }}, - } - - return syntax.Marshal(list) -} - -func (sni *ServerNameExtension) Unmarshal(data []byte) (int, error) { - var list serverNameListInner - read, err := syntax.Unmarshal(data, &list) - if err != nil { - return 0, err - } - - // Syntax requires at least one entry - // Entries beyond the first are ignored - if nameType := list.ServerNameList[0].NameType; nameType != 0x00 { - return 0, fmt.Errorf("tls.servername: Unsupported name type [%x]", nameType) - } - - *sni = ServerNameExtension(list.ServerNameList[0].HostName) - return read, nil -} - -// struct { -// NamedGroup group; -// opaque key_exchange<1..2^16-1>; -// } KeyShareEntry; -// -// struct { -// select (Handshake.msg_type) { -// case client_hello: -// KeyShareEntry client_shares<0..2^16-1>; -// -// case hello_retry_request: -// NamedGroup selected_group; -// -// case server_hello: -// KeyShareEntry server_share; -// }; -// } KeyShare; -type KeyShareEntry struct { - Group NamedGroup - KeyExchange []byte `tls:"head=2,min=1"` -} - -func (kse KeyShareEntry) SizeValid() bool { - return len(kse.KeyExchange) == keyExchangeSizeFromNamedGroup(kse.Group) -} - -type KeyShareExtension struct { - HandshakeType HandshakeType - SelectedGroup NamedGroup - Shares []KeyShareEntry -} - -type KeyShareClientHelloInner struct { - ClientShares []KeyShareEntry `tls:"head=2,min=0"` -} -type KeyShareHelloRetryInner struct { - SelectedGroup NamedGroup -} -type KeyShareServerHelloInner struct { - ServerShare KeyShareEntry -} - -func (ks KeyShareExtension) Type() ExtensionType { - return ExtensionTypeKeyShare -} - -func (ks KeyShareExtension) Marshal() ([]byte, error) { - switch ks.HandshakeType { - case HandshakeTypeClientHello: - for _, share := range ks.Shares { - if !share.SizeValid() { - return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - } - return syntax.Marshal(KeyShareClientHelloInner{ks.Shares}) - - case HandshakeTypeHelloRetryRequest: - if len(ks.Shares) > 0 { - return nil, fmt.Errorf("tls.keyshare: Key shares not allowed for HelloRetryRequest") - } - - return syntax.Marshal(KeyShareHelloRetryInner{ks.SelectedGroup}) - - case HandshakeTypeServerHello: - if len(ks.Shares) != 1 { - return nil, fmt.Errorf("tls.keyshare: Server must send exactly one key share") - } - - if !ks.Shares[0].SizeValid() { - return nil, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - - return syntax.Marshal(KeyShareServerHelloInner{ks.Shares[0]}) - - default: - return nil, fmt.Errorf("tls.keyshare: Handshake type not allowed") - } -} - -func (ks *KeyShareExtension) Unmarshal(data []byte) (int, error) { - switch ks.HandshakeType { - case HandshakeTypeClientHello: - var inner KeyShareClientHelloInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - for _, share := range inner.ClientShares { - if !share.SizeValid() { - return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - } - - ks.Shares = inner.ClientShares - return read, nil - - case HandshakeTypeHelloRetryRequest: - var inner KeyShareHelloRetryInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - ks.SelectedGroup = inner.SelectedGroup - return read, nil - - case HandshakeTypeServerHello: - var inner KeyShareServerHelloInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - if !inner.ServerShare.SizeValid() { - return 0, fmt.Errorf("tls.keyshare: Key share has wrong size for group") - } - - ks.Shares = []KeyShareEntry{inner.ServerShare} - return read, nil - - default: - return 0, fmt.Errorf("tls.keyshare: Handshake type not allowed") - } -} - -// struct { -// NamedGroup named_group_list<2..2^16-1>; -// } NamedGroupList; -type SupportedGroupsExtension struct { - Groups []NamedGroup `tls:"head=2,min=2"` -} - -func (sg SupportedGroupsExtension) Type() ExtensionType { - return ExtensionTypeSupportedGroups -} - -func (sg SupportedGroupsExtension) Marshal() ([]byte, error) { - return syntax.Marshal(sg) -} - -func (sg *SupportedGroupsExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sg) -} - -// struct { -// SignatureScheme supported_signature_algorithms<2..2^16-2>; -// } SignatureSchemeList -type SignatureAlgorithmsExtension struct { - Algorithms []SignatureScheme `tls:"head=2,min=2"` -} - -func (sa SignatureAlgorithmsExtension) Type() ExtensionType { - return ExtensionTypeSignatureAlgorithms -} - -func (sa SignatureAlgorithmsExtension) Marshal() ([]byte, error) { - return syntax.Marshal(sa) -} - -func (sa *SignatureAlgorithmsExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sa) -} - -// struct { -// opaque identity<1..2^16-1>; -// uint32 obfuscated_ticket_age; -// } PskIdentity; -// -// opaque PskBinderEntry<32..255>; -// -// struct { -// select (Handshake.msg_type) { -// case client_hello: -// PskIdentity identities<7..2^16-1>; -// PskBinderEntry binders<33..2^16-1>; -// -// case server_hello: -// uint16 selected_identity; -// }; -// -// } PreSharedKeyExtension; -type PSKIdentity struct { - Identity []byte `tls:"head=2,min=1"` - ObfuscatedTicketAge uint32 -} - -type PSKBinderEntry struct { - Binder []byte `tls:"head=1,min=32"` -} - -type PreSharedKeyExtension struct { - HandshakeType HandshakeType - Identities []PSKIdentity - Binders []PSKBinderEntry - SelectedIdentity uint16 -} - -type preSharedKeyClientInner struct { - Identities []PSKIdentity `tls:"head=2,min=7"` - Binders []PSKBinderEntry `tls:"head=2,min=33"` -} - -type preSharedKeyServerInner struct { - SelectedIdentity uint16 -} - -func (psk PreSharedKeyExtension) Type() ExtensionType { - return ExtensionTypePreSharedKey -} - -func (psk PreSharedKeyExtension) Marshal() ([]byte, error) { - switch psk.HandshakeType { - case HandshakeTypeClientHello: - return syntax.Marshal(preSharedKeyClientInner{ - Identities: psk.Identities, - Binders: psk.Binders, - }) - - case HandshakeTypeServerHello: - if len(psk.Identities) > 0 || len(psk.Binders) > 0 { - return nil, fmt.Errorf("tls.presharedkey: Server can only provide an index") - } - return syntax.Marshal(preSharedKeyServerInner{psk.SelectedIdentity}) - - default: - return nil, fmt.Errorf("tls.presharedkey: Handshake type not supported") - } -} - -func (psk *PreSharedKeyExtension) Unmarshal(data []byte) (int, error) { - switch psk.HandshakeType { - case HandshakeTypeClientHello: - var inner preSharedKeyClientInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - if len(inner.Identities) != len(inner.Binders) { - return 0, fmt.Errorf("Lengths of identities and binders not equal") - } - - psk.Identities = inner.Identities - psk.Binders = inner.Binders - return read, nil - - case HandshakeTypeServerHello: - var inner preSharedKeyServerInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - psk.SelectedIdentity = inner.SelectedIdentity - return read, nil - - default: - return 0, fmt.Errorf("tls.presharedkey: Handshake type not supported") - } -} - -func (psk PreSharedKeyExtension) HasIdentity(id []byte) ([]byte, bool) { - for i, localID := range psk.Identities { - if bytes.Equal(localID.Identity, id) { - return psk.Binders[i].Binder, true - } - } - return nil, false -} - -// enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode; -// -// struct { -// PskKeyExchangeMode ke_modes<1..255>; -// } PskKeyExchangeModes; -type PSKKeyExchangeModesExtension struct { - KEModes []PSKKeyExchangeMode `tls:"head=1,min=1"` -} - -func (pkem PSKKeyExchangeModesExtension) Type() ExtensionType { - return ExtensionTypePSKKeyExchangeModes -} - -func (pkem PSKKeyExchangeModesExtension) Marshal() ([]byte, error) { - return syntax.Marshal(pkem) -} - -func (pkem *PSKKeyExchangeModesExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, pkem) -} - -// struct { -// } EarlyDataIndication; - -type EarlyDataExtension struct{} - -func (ed EarlyDataExtension) Type() ExtensionType { - return ExtensionTypeEarlyData -} - -func (ed EarlyDataExtension) Marshal() ([]byte, error) { - return []byte{}, nil -} - -func (ed *EarlyDataExtension) Unmarshal(data []byte) (int, error) { - return 0, nil -} - -// struct { -// uint32 max_early_data_size; -// } TicketEarlyDataInfo; - -type TicketEarlyDataInfoExtension struct { - MaxEarlyDataSize uint32 -} - -func (tedi TicketEarlyDataInfoExtension) Type() ExtensionType { - return ExtensionTypeTicketEarlyDataInfo -} - -func (tedi TicketEarlyDataInfoExtension) Marshal() ([]byte, error) { - return syntax.Marshal(tedi) -} - -func (tedi *TicketEarlyDataInfoExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, tedi) -} - -// opaque ProtocolName<1..2^8-1>; -// -// struct { -// ProtocolName protocol_name_list<2..2^16-1> -// } ProtocolNameList; -type ALPNExtension struct { - Protocols []string -} - -type protocolNameInner struct { - Name []byte `tls:"head=1,min=1"` -} - -type alpnExtensionInner struct { - Protocols []protocolNameInner `tls:"head=2,min=2"` -} - -func (alpn ALPNExtension) Type() ExtensionType { - return ExtensionTypeALPN -} - -func (alpn ALPNExtension) Marshal() ([]byte, error) { - protocols := make([]protocolNameInner, len(alpn.Protocols)) - for i, protocol := range alpn.Protocols { - protocols[i] = protocolNameInner{[]byte(protocol)} - } - return syntax.Marshal(alpnExtensionInner{protocols}) -} - -func (alpn *ALPNExtension) Unmarshal(data []byte) (int, error) { - var inner alpnExtensionInner - read, err := syntax.Unmarshal(data, &inner) - - if err != nil { - return 0, err - } - - alpn.Protocols = make([]string, len(inner.Protocols)) - for i, protocol := range inner.Protocols { - alpn.Protocols[i] = string(protocol.Name) - } - return read, nil -} - -// struct { -// ProtocolVersion versions<2..254>; -// } SupportedVersions; -type SupportedVersionsExtension struct { - HandshakeType HandshakeType - Versions []uint16 -} - -type SupportedVersionsClientHelloInner struct { - Versions []uint16 `tls:"head=1,min=2,max=254"` -} - -type SupportedVersionsServerHelloInner struct { - Version uint16 -} - -func (sv SupportedVersionsExtension) Type() ExtensionType { - return ExtensionTypeSupportedVersions -} - -func (sv SupportedVersionsExtension) Marshal() ([]byte, error) { - switch sv.HandshakeType { - case HandshakeTypeClientHello: - return syntax.Marshal(SupportedVersionsClientHelloInner{sv.Versions}) - case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: - return syntax.Marshal(SupportedVersionsServerHelloInner{sv.Versions[0]}) - default: - return nil, fmt.Errorf("tls.supported_versions: Handshake type not allowed") - } -} - -func (sv *SupportedVersionsExtension) Unmarshal(data []byte) (int, error) { - switch sv.HandshakeType { - case HandshakeTypeClientHello: - var inner SupportedVersionsClientHelloInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - sv.Versions = inner.Versions - return read, nil - - case HandshakeTypeServerHello, HandshakeTypeHelloRetryRequest: - var inner SupportedVersionsServerHelloInner - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - sv.Versions = []uint16{inner.Version} - return read, nil - - default: - return 0, fmt.Errorf("tls.supported_versions: Handshake type not allowed") - } -} - -// struct { -// opaque cookie<1..2^16-1>; -// } Cookie; -type CookieExtension struct { - Cookie []byte `tls:"head=2,min=1"` -} - -func (c CookieExtension) Type() ExtensionType { - return ExtensionTypeCookie -} - -func (c CookieExtension) Marshal() ([]byte, error) { - return syntax.Marshal(c) -} - -func (c *CookieExtension) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, c) -} diff --git a/vendor/github.com/bifurcation/mint/ffdhe.go b/vendor/github.com/bifurcation/mint/ffdhe.go deleted file mode 100644 index 59d1f7f9..00000000 --- a/vendor/github.com/bifurcation/mint/ffdhe.go +++ /dev/null @@ -1,147 +0,0 @@ -package mint - -import ( - "encoding/hex" - "math/big" -) - -var ( - finiteFieldPrime2048hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B423861285C97FFFFFFFFFFFFFFFF" - finiteFieldPrime2048bytes, _ = hex.DecodeString(finiteFieldPrime2048hex) - finiteFieldPrime2048 = big.NewInt(0).SetBytes(finiteFieldPrime2048bytes) - - finiteFieldPrime3072hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B66C62E37FFFFFFFFFFFFFFFF" - finiteFieldPrime3072bytes, _ = hex.DecodeString(finiteFieldPrime3072hex) - finiteFieldPrime3072 = big.NewInt(0).SetBytes(finiteFieldPrime3072bytes) - - finiteFieldPrime4096hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + - "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + - "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + - "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + - "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + - "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E655F6A" + - "FFFFFFFFFFFFFFFF" - finiteFieldPrime4096bytes, _ = hex.DecodeString(finiteFieldPrime4096hex) - finiteFieldPrime4096 = big.NewInt(0).SetBytes(finiteFieldPrime4096bytes) - - finiteFieldPrime6144hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + - "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + - "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + - "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + - "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + - "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + - "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + - "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + - "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + - "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + - "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + - "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + - "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + - "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + - "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + - "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + - "A41D570D7938DAD4A40E329CD0E40E65FFFFFFFFFFFFFFFF" - finiteFieldPrime6144bytes, _ = hex.DecodeString(finiteFieldPrime6144hex) - finiteFieldPrime6144 = big.NewInt(0).SetBytes(finiteFieldPrime6144bytes) - - finiteFieldPrime8192hex = "FFFFFFFFFFFFFFFFADF85458A2BB4A9AAFDC5620273D3CF1" + - "D8B9C583CE2D3695A9E13641146433FBCC939DCE249B3EF9" + - "7D2FE363630C75D8F681B202AEC4617AD3DF1ED5D5FD6561" + - "2433F51F5F066ED0856365553DED1AF3B557135E7F57C935" + - "984F0C70E0E68B77E2A689DAF3EFE8721DF158A136ADE735" + - "30ACCA4F483A797ABC0AB182B324FB61D108A94BB2C8E3FB" + - "B96ADAB760D7F4681D4F42A3DE394DF4AE56EDE76372BB19" + - "0B07A7C8EE0A6D709E02FCE1CDF7E2ECC03404CD28342F61" + - "9172FE9CE98583FF8E4F1232EEF28183C3FE3B1B4C6FAD73" + - "3BB5FCBC2EC22005C58EF1837D1683B2C6F34A26C1B2EFFA" + - "886B4238611FCFDCDE355B3B6519035BBC34F4DEF99C0238" + - "61B46FC9D6E6C9077AD91D2691F7F7EE598CB0FAC186D91C" + - "AEFE130985139270B4130C93BC437944F4FD4452E2D74DD3" + - "64F2E21E71F54BFF5CAE82AB9C9DF69EE86D2BC522363A0D" + - "ABC521979B0DEADA1DBF9A42D5C4484E0ABCD06BFA53DDEF" + - "3C1B20EE3FD59D7C25E41D2B669E1EF16E6F52C3164DF4FB" + - "7930E9E4E58857B6AC7D5F42D69F6D187763CF1D55034004" + - "87F55BA57E31CC7A7135C886EFB4318AED6A1E012D9E6832" + - "A907600A918130C46DC778F971AD0038092999A333CB8B7A" + - "1A1DB93D7140003C2A4ECEA9F98D0ACC0A8291CDCEC97DCF" + - "8EC9B55A7F88A46B4DB5A851F44182E1C68A007E5E0DD902" + - "0BFD64B645036C7A4E677D2C38532A3A23BA4442CAF53EA6" + - "3BB454329B7624C8917BDD64B1C0FD4CB38E8C334C701C3A" + - "CDAD0657FCCFEC719B1F5C3E4E46041F388147FB4CFDB477" + - "A52471F7A9A96910B855322EDB6340D8A00EF092350511E3" + - "0ABEC1FFF9E3A26E7FB29F8C183023C3587E38DA0077D9B4" + - "763E4E4B94B2BBC194C6651E77CAF992EEAAC0232A281BF6" + - "B3A739C1226116820AE8DB5847A67CBEF9C9091B462D538C" + - "D72B03746AE77F5E62292C311562A846505DC82DB854338A" + - "E49F5235C95B91178CCF2DD5CACEF403EC9D1810C6272B04" + - "5B3B71F9DC6B80D63FDD4A8E9ADB1E6962A69526D43161C1" + - "A41D570D7938DAD4A40E329CCFF46AAA36AD004CF600C838" + - "1E425A31D951AE64FDB23FCEC9509D43687FEB69EDD1CC5E" + - "0B8CC3BDF64B10EF86B63142A3AB8829555B2F747C932665" + - "CB2C0F1CC01BD70229388839D2AF05E454504AC78B758282" + - "2846C0BA35C35F5C59160CC046FD8251541FC68C9C86B022" + - "BB7099876A460E7451A8A93109703FEE1C217E6C3826E52C" + - "51AA691E0E423CFC99E9E31650C1217B624816CDAD9A95F9" + - "D5B8019488D9C0A0A1FE3075A577E23183F81D4A3F2FA457" + - "1EFC8CE0BA8A4FE8B6855DFE72B0A66EDED2FBABFBE58A30" + - "FAFABE1C5D71A87E2F741EF8C1FE86FEA6BBFDE530677F0D" + - "97D11D49F7A8443D0822E506A9F4614E011E2A94838FF88C" + - "D68C8BB7C5C6424CFFFFFFFFFFFFFFFF" - finiteFieldPrime8192bytes, _ = hex.DecodeString(finiteFieldPrime8192hex) - finiteFieldPrime8192 = big.NewInt(0).SetBytes(finiteFieldPrime8192bytes) -) diff --git a/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/bifurcation/mint/frame-reader.go deleted file mode 100644 index 4ccfc23f..00000000 --- a/vendor/github.com/bifurcation/mint/frame-reader.go +++ /dev/null @@ -1,98 +0,0 @@ -// Read a generic "framed" packet consisting of a header and a -// This is used for both TLS Records and TLS Handshake Messages -package mint - -type framing interface { - headerLen() int - defaultReadLen() int - frameLen(hdr []byte) (int, error) -} - -const ( - kFrameReaderHdr = 0 - kFrameReaderBody = 1 -) - -type frameNextAction func(f *frameReader) error - -type frameReader struct { - details framing - state uint8 - header []byte - body []byte - working []byte - writeOffset int - remainder []byte -} - -func newFrameReader(d framing) *frameReader { - hdr := make([]byte, d.headerLen()) - return &frameReader{ - d, - kFrameReaderHdr, - hdr, - nil, - hdr, - 0, - nil, - } -} - -func dup(a []byte) []byte { - r := make([]byte, len(a)) - copy(r, a) - return r -} - -func (f *frameReader) needed() int { - tmp := (len(f.working) - f.writeOffset) - len(f.remainder) - if tmp < 0 { - return 0 - } - return tmp -} - -func (f *frameReader) addChunk(in []byte) { - // Append to the buffer. - logf(logTypeFrameReader, "Appending %v", len(in)) - f.remainder = append(f.remainder, in...) -} - -func (f *frameReader) process() (hdr []byte, body []byte, err error) { - for f.needed() == 0 { - logf(logTypeFrameReader, "%v bytes needed for next block", len(f.working)-f.writeOffset) - // Fill out our working block - copied := copy(f.working[f.writeOffset:], f.remainder) - f.remainder = f.remainder[copied:] - f.writeOffset += copied - if f.writeOffset < len(f.working) { - logf(logTypeVerbose, "Read would have blocked 1") - return nil, nil, AlertWouldBlock - } - // Reset the write offset, because we are now full. - f.writeOffset = 0 - - // We have read a full frame - if f.state == kFrameReaderBody { - logf(logTypeFrameReader, "Returning frame hdr=%#x len=%d buffered=%d", f.header, len(f.body), len(f.remainder)) - f.state = kFrameReaderHdr - f.working = f.header - return dup(f.header), dup(f.body), nil - } - - // We have read the header - bodyLen, err := f.details.frameLen(f.header) - if err != nil { - return nil, nil, err - } - logf(logTypeFrameReader, "Processed header, body len = %v", bodyLen) - - f.body = make([]byte, bodyLen) - f.working = f.body - f.writeOffset = 0 - f.state = kFrameReaderBody - } - - logf(logTypeVerbose, "Read would have blocked 2") - return nil, nil, AlertWouldBlock -} diff --git a/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/bifurcation/mint/handshake-layer.go deleted file mode 100644 index de17b30b..00000000 --- a/vendor/github.com/bifurcation/mint/handshake-layer.go +++ /dev/null @@ -1,551 +0,0 @@ -package mint - -import ( - "fmt" - "io" - "net" -) - -const ( - handshakeHeaderLenTLS = 4 // handshake message header length - handshakeHeaderLenDTLS = 12 // handshake message header length - maxHandshakeMessageLen = 1 << 24 // max handshake message length -) - -// struct { -// HandshakeType msg_type; /* handshake type */ -// uint24 length; /* bytes in message */ -// select (HandshakeType) { -// ... -// } body; -// } Handshake; -// -// We do the select{...} part in a different layer, so we treat the -// actual message body as opaque: -// -// struct { -// HandshakeType msg_type; -// opaque msg<0..2^24-1> -// } Handshake; -// -type HandshakeMessage struct { - msgType HandshakeType - seq uint32 - body []byte - datagram bool - offset uint32 // Used for DTLS - length uint32 - cipher *cipherState -} - -// Note: This could be done with the `syntax` module, using the simplified -// syntax as discussed above. However, since this is so simple, there's not -// much benefit to doing so. -// When datagram is set, we marshal this as a whole DTLS record. -func (hm *HandshakeMessage) Marshal() []byte { - if hm == nil { - return []byte{} - } - - fragLen := len(hm.body) - var data []byte - - if hm.datagram { - data = make([]byte, handshakeHeaderLenDTLS+fragLen) - } else { - data = make([]byte, handshakeHeaderLenTLS+fragLen) - } - tmp := data - tmp = encodeUint(uint64(hm.msgType), 1, tmp) - tmp = encodeUint(uint64(hm.length), 3, tmp) - if hm.datagram { - tmp = encodeUint(uint64(hm.seq), 2, tmp) - tmp = encodeUint(uint64(hm.offset), 3, tmp) - tmp = encodeUint(uint64(fragLen), 3, tmp) - } - copy(tmp, hm.body) - return data -} - -func (hm HandshakeMessage) ToBody() (HandshakeMessageBody, error) { - logf(logTypeHandshake, "HandshakeMessage.toBody [%d] [%x]", hm.msgType, hm.body) - - var body HandshakeMessageBody - switch hm.msgType { - case HandshakeTypeClientHello: - body = new(ClientHelloBody) - case HandshakeTypeServerHello: - body = new(ServerHelloBody) - case HandshakeTypeEncryptedExtensions: - body = new(EncryptedExtensionsBody) - case HandshakeTypeCertificate: - body = new(CertificateBody) - case HandshakeTypeCertificateRequest: - body = new(CertificateRequestBody) - case HandshakeTypeCertificateVerify: - body = new(CertificateVerifyBody) - case HandshakeTypeFinished: - body = &FinishedBody{VerifyDataLen: len(hm.body)} - case HandshakeTypeNewSessionTicket: - body = new(NewSessionTicketBody) - case HandshakeTypeKeyUpdate: - body = new(KeyUpdateBody) - case HandshakeTypeEndOfEarlyData: - body = new(EndOfEarlyDataBody) - default: - return body, fmt.Errorf("tls.handshakemessage: Unsupported body type") - } - - err := safeUnmarshal(body, hm.body) - return body, err -} - -func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*HandshakeMessage, error) { - data, err := body.Marshal() - if err != nil { - return nil, err - } - - m := &HandshakeMessage{ - msgType: body.Type(), - body: data, - seq: h.msgSeq, - datagram: h.datagram, - length: uint32(len(data)), - } - h.msgSeq++ - return m, nil -} - -type HandshakeLayer struct { - ctx *HandshakeContext // The handshake we are attached to - nonblocking bool // Should we operate in nonblocking mode - conn *RecordLayer // Used for reading/writing records - frame *frameReader // The buffered frame reader - datagram bool // Is this DTLS? - msgSeq uint32 // The DTLS message sequence number - queued []*HandshakeMessage // In/out queue - sent []*HandshakeMessage // Sent messages for DTLS - recvdRecords []uint64 // Records we have received. - maxFragmentLen int -} - -type handshakeLayerFrameDetails struct { - datagram bool -} - -func (d handshakeLayerFrameDetails) headerLen() int { - if d.datagram { - return handshakeHeaderLenDTLS - } - return handshakeHeaderLenTLS -} - -func (d handshakeLayerFrameDetails) defaultReadLen() int { - return d.headerLen() + maxFragmentLen -} - -func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { - logf(logTypeIO, "Header=%x", hdr) - // The length of this fragment (as opposed to the message) - // is always the last three bytes for both TLS and DTLS - val, _ := decodeUint(hdr[len(hdr)-3:], 3) - return int(val), nil -} - -func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { - h := HandshakeLayer{} - h.ctx = c - h.conn = r - h.datagram = false - h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) - h.maxFragmentLen = maxFragmentLen - return &h -} - -func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { - h := HandshakeLayer{} - h.ctx = c - h.conn = r - h.datagram = true - h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) - h.maxFragmentLen = initialMtu // Not quite right - return &h -} - -func (h *HandshakeLayer) readRecord() error { - logf(logTypeVerbose, "Trying to read record") - pt, err := h.conn.readRecordAnyEpoch() - if err != nil { - return err - } - - switch pt.contentType { - case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck: - default: - return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) - } - - if pt.contentType == RecordTypeAck { - if !h.datagram { - return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS") - } - logf(logTypeIO, "read ACK") - return h.ctx.processAck(pt.fragment) - } - - if pt.contentType == RecordTypeAlert { - logf(logTypeIO, "read alert %v", pt.fragment[1]) - if len(pt.fragment) < 2 { - h.sendAlert(AlertUnexpectedMessage) - return io.EOF - } - return Alert(pt.fragment[1]) - } - - assert(h.ctx.hIn.conn != nil) - if pt.epoch != h.ctx.hIn.conn.cipher.epoch { - // This is out of order but we're dropping it. - // TODO(ekr@rtfm.com): If server, need to retransmit Finished. - if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData { - return nil - } - - // Anything else shouldn't happen. - return AlertIllegalParameter - } - - h.recvdRecords = append(h.recvdRecords, pt.seq) - h.frame.addChunk(pt.fragment) - - return nil -} - -// sendAlert sends a TLS alert message. -func (h *HandshakeLayer) sendAlert(err Alert) error { - tmp := make([]byte, 2) - tmp[0] = AlertLevelError - tmp[1] = byte(err) - h.conn.WriteRecord(&TLSPlaintext{ - contentType: RecordTypeAlert, - fragment: tmp}, - ) - - // closeNotify is a special case in that it isn't an error: - if err != AlertCloseNotify { - return &net.OpError{Op: "local error", Err: err} - } - return nil -} - -func (h *HandshakeLayer) noteMessageDelivered(seq uint32) { - h.msgSeq = seq + 1 - var i int - var m *HandshakeMessage - for i, m = range h.queued { - if m.seq > seq { - break - } - } - h.queued = h.queued[i:] -} - -func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) { - if hm.seq < h.msgSeq { - return nil, nil - } - - // TODO(ekr@rtfm.com): Send an ACK immediately if we got something - // out of order. - h.ctx.receivedHandshakeMessage() - - if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { - // TODO(ekr@rtfm.com): Check the length? - // This is complete. - h.noteMessageDelivered(hm.seq) - return hm, nil - } - - // Now insert sorted. - var i int - for i = 0; i < len(h.queued); i++ { - f := h.queued[i] - if hm.seq < f.seq { - break - } - if hm.offset < f.offset { - break - } - } - tmp := make([]*HandshakeMessage, 0, len(h.queued)+1) - tmp = append(tmp, h.queued[:i]...) - tmp = append(tmp, hm) - tmp = append(tmp, h.queued[i:]...) - h.queued = tmp - - return h.checkMessageAvailable() -} - -func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { - if len(h.queued) == 0 { - return nil, nil - } - - hm := h.queued[0] - if hm.seq != h.msgSeq { - return nil, nil - } - - if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { - // TODO(ekr@rtfm.com): Check the length? - // This is complete. - h.noteMessageDelivered(hm.seq) - return hm, nil - } - - // OK, this at least might complete the message. - end := uint32(0) - buf := make([]byte, hm.length) - - for _, f := range h.queued { - // Out of fragments - if f.seq > hm.seq { - break - } - - if f.length != uint32(len(buf)) { - return nil, fmt.Errorf("Mismatched DTLS length") - } - - if f.offset > end { - break - } - - if f.offset+uint32(len(f.body)) > end { - // OK, this is adding something we don't know about - copy(buf[f.offset:], f.body) - end = f.offset + uint32(len(f.body)) - if end == hm.length { - h2 := *hm - h2.offset = 0 - h2.body = buf - h.noteMessageDelivered(hm.seq) - return &h2, nil - } - } - - } - - return nil, nil -} - -func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { - var hdr, body []byte - var err error - - hm, err := h.checkMessageAvailable() - if err != nil { - return nil, err - } - if hm != nil { - return hm, nil - } - for { - logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) - if h.frame.needed() > 0 { - logf(logTypeVerbose, "Trying to read a new record") - err = h.readRecord() - - if err != nil && (h.nonblocking || err != AlertWouldBlock) { - return nil, err - } - } - - hdr, body, err = h.frame.process() - if err == nil { - break - } - if err != nil && (h.nonblocking || err != AlertWouldBlock) { - return nil, err - } - } - - logf(logTypeHandshake, "read handshake message") - - hm = &HandshakeMessage{} - hm.msgType = HandshakeType(hdr[0]) - hm.datagram = h.datagram - hm.body = make([]byte, len(body)) - copy(hm.body, body) - logf(logTypeHandshake, "Read message with type: %v", hm.msgType) - if h.datagram { - tmp, hdr := decodeUint(hdr[1:], 3) - hm.length = uint32(tmp) - tmp, hdr = decodeUint(hdr, 2) - hm.seq = uint32(tmp) - tmp, hdr = decodeUint(hdr, 3) - hm.offset = uint32(tmp) - - return h.newFragmentReceived(hm) - } - - hm.length = uint32(len(body)) - return hm, nil -} - -func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error { - hm.cipher = h.conn.cipher - h.queued = append(h.queued, hm) - return nil -} - -func (h *HandshakeLayer) SendQueuedMessages() (int, error) { - logf(logTypeHandshake, "Sending outgoing messages") - count, err := h.WriteMessages(h.queued) - if !h.datagram { - h.ClearQueuedMessages() - } - return count, err -} - -func (h *HandshakeLayer) ClearQueuedMessages() { - logf(logTypeHandshake, "Clearing outgoing hs message queue") - h.queued = nil -} - -func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) { - var buf []byte - - // Figure out if we're going to want the full header or just - // the body - hdrlen := 0 - if hm.datagram { - hdrlen = handshakeHeaderLenDTLS - } else if start == 0 { - hdrlen = handshakeHeaderLenTLS - } - - // Compute the amount of body we can fit in - room -= hdrlen - if room == 0 { - // This works because we are doing one record per - // message - panic("Too short max fragment len") - } - bodylen := len(hm.body) - start - if bodylen > room { - bodylen = room - } - body := hm.body[start : start+bodylen] - - // Now see if this chunk has been ACKed. This doesn't produce ideal - // retransmission but is simple. - if h.ctx.fragmentAcked(hm.seq, start, bodylen) { - logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen) - return false, start + bodylen, nil - } - - // Encode the data. - if hdrlen > 0 { - hm2 := *hm - hm2.offset = uint32(start) - hm2.body = body - buf = hm2.Marshal() - hm = &hm2 - } else { - buf = body - } - - if h.datagram { - // Remember that we sent this. - h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{ - hm.seq, - start, - len(body), - h.conn.cipher.combineSeq(true), - false, - }) - } - return true, start + bodylen, h.conn.writeRecordWithPadding( - &TLSPlaintext{ - contentType: RecordTypeHandshake, - fragment: buf, - }, - hm.cipher, 0) -} - -func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) { - start := int(0) - - if len(hm.body) > maxHandshakeMessageLen { - return 0, fmt.Errorf("Tried to write a handshake message that's too long") - } - - written := 0 - wrote := false - - // Always make one pass through to allow EOED (which is empty). - for { - var err error - wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen) - if err != nil { - return 0, err - } - if wrote { - written++ - } - if start >= len(hm.body) { - break - } - } - - return written, nil -} - -func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) { - written := 0 - for _, hm := range hms { - logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) - - wrote, err := h.WriteMessage(hm) - if err != nil { - return 0, err - } - written += wrote - } - return written, nil -} - -func encodeUint(v uint64, size int, out []byte) []byte { - for i := size - 1; i >= 0; i-- { - out[i] = byte(v & 0xff) - v >>= 8 - } - return out[size:] -} - -func decodeUint(in []byte, size int) (uint64, []byte) { - val := uint64(0) - - for i := 0; i < size; i++ { - val <<= 8 - val += uint64(in[i]) - } - return val, in[size:] -} - -type marshalledPDU interface { - Marshal() ([]byte, error) - Unmarshal(data []byte) (int, error) -} - -func safeUnmarshal(pdu marshalledPDU, data []byte) error { - read, err := pdu.Unmarshal(data) - if err != nil { - return err - } - if len(data) != read { - return fmt.Errorf("Invalid encoding: Extra data not consumed") - } - return nil -} diff --git a/vendor/github.com/bifurcation/mint/handshake-messages.go b/vendor/github.com/bifurcation/mint/handshake-messages.go deleted file mode 100644 index 5a229f1d..00000000 --- a/vendor/github.com/bifurcation/mint/handshake-messages.go +++ /dev/null @@ -1,481 +0,0 @@ -package mint - -import ( - "bytes" - "crypto" - "crypto/x509" - "encoding/binary" - "fmt" - - "github.com/bifurcation/mint/syntax" -) - -type HandshakeMessageBody interface { - Type() HandshakeType - Marshal() ([]byte, error) - Unmarshal(data []byte) (int, error) -} - -// struct { -// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ -// Random random; -// opaque legacy_session_id<0..32>; -// CipherSuite cipher_suites<2..2^16-2>; -// opaque legacy_compression_methods<1..2^8-1>; -// Extension extensions<0..2^16-1>; -// } ClientHello; -type ClientHelloBody struct { - LegacyVersion uint16 - Random [32]byte - LegacySessionID []byte - CipherSuites []CipherSuite - Extensions ExtensionList -} - -type clientHelloBodyInnerTLS struct { - LegacyVersion uint16 - Random [32]byte - LegacySessionID []byte `tls:"head=1,max=32"` - CipherSuites []CipherSuite `tls:"head=2,min=2"` - LegacyCompressionMethods []byte `tls:"head=1,min=1"` - Extensions []Extension `tls:"head=2"` -} - -type clientHelloBodyInnerDTLS struct { - LegacyVersion uint16 - Random [32]byte - LegacySessionID []byte `tls:"head=1,max=32"` - EmptyCookie uint8 - CipherSuites []CipherSuite `tls:"head=2,min=2"` - LegacyCompressionMethods []byte `tls:"head=1,min=1"` - Extensions []Extension `tls:"head=2"` -} - -func (ch ClientHelloBody) Type() HandshakeType { - return HandshakeTypeClientHello -} - -func (ch ClientHelloBody) Marshal() ([]byte, error) { - if ch.LegacyVersion == tls12Version { - return syntax.Marshal(clientHelloBodyInnerTLS{ - LegacyVersion: ch.LegacyVersion, - Random: ch.Random, - LegacySessionID: []byte{}, - CipherSuites: ch.CipherSuites, - LegacyCompressionMethods: []byte{0}, - Extensions: ch.Extensions, - }) - } else { - return syntax.Marshal(clientHelloBodyInnerDTLS{ - LegacyVersion: ch.LegacyVersion, - Random: ch.Random, - LegacySessionID: []byte{}, - CipherSuites: ch.CipherSuites, - LegacyCompressionMethods: []byte{0}, - Extensions: ch.Extensions, - }) - } - -} - -func (ch *ClientHelloBody) Unmarshal(data []byte) (int, error) { - var read int - var err error - - // Note that this might be 0, in which case we do TLS. That - // makes the tests easier. - if ch.LegacyVersion != dtls12WireVersion { - var inner clientHelloBodyInnerTLS - read, err = syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { - return 0, fmt.Errorf("tls.clienthello: Invalid compression method") - } - - ch.LegacyVersion = inner.LegacyVersion - ch.Random = inner.Random - ch.LegacySessionID = inner.LegacySessionID - ch.CipherSuites = inner.CipherSuites - ch.Extensions = inner.Extensions - } else { - var inner clientHelloBodyInnerDTLS - read, err = syntax.Unmarshal(data, &inner) - if err != nil { - return 0, err - } - - if inner.EmptyCookie != 0 { - return 0, fmt.Errorf("tls.clienthello: Invalid cookie") - } - - if len(inner.LegacyCompressionMethods) != 1 || inner.LegacyCompressionMethods[0] != 0 { - return 0, fmt.Errorf("tls.clienthello: Invalid compression method") - } - - ch.LegacyVersion = inner.LegacyVersion - ch.Random = inner.Random - ch.LegacySessionID = inner.LegacySessionID - ch.CipherSuites = inner.CipherSuites - ch.Extensions = inner.Extensions - } - return read, nil -} - -// TODO: File a spec bug to clarify this -func (ch ClientHelloBody) Truncated() ([]byte, error) { - if len(ch.Extensions) == 0 { - return nil, fmt.Errorf("tls.clienthello.truncate: No extensions") - } - - pskExt := ch.Extensions[len(ch.Extensions)-1] - if pskExt.ExtensionType != ExtensionTypePreSharedKey { - return nil, fmt.Errorf("tls.clienthello.truncate: Last extension is not PSK") - } - - body, err := ch.Marshal() - if err != nil { - return nil, err - } - chm := &HandshakeMessage{ - msgType: ch.Type(), - body: body, - length: uint32(len(body)), - } - chData := chm.Marshal() - - psk := PreSharedKeyExtension{ - HandshakeType: HandshakeTypeClientHello, - } - _, err = psk.Unmarshal(pskExt.ExtensionData) - if err != nil { - return nil, err - } - - // Marshal just the binders so that we know how much to truncate - binders := struct { - Binders []PSKBinderEntry `tls:"head=2,min=33"` - }{Binders: psk.Binders} - binderData, _ := syntax.Marshal(binders) - binderLen := len(binderData) - - chLen := len(chData) - return chData[:chLen-binderLen], nil -} - -// struct { -// ProtocolVersion legacy_version = 0x0303; /* TLS v1.2 */ -// Random random; -// opaque legacy_session_id_echo<0..32>; -// CipherSuite cipher_suite; -// uint8 legacy_compression_method = 0; -// Extension extensions<6..2^16-1>; -// } ServerHello; -type ServerHelloBody struct { - Version uint16 - Random [32]byte - LegacySessionID []byte `tls:"head=1,max=32"` - CipherSuite CipherSuite - LegacyCompressionMethod uint8 - Extensions ExtensionList `tls:"head=2"` -} - -func (sh ServerHelloBody) Type() HandshakeType { - return HandshakeTypeServerHello -} - -func (sh ServerHelloBody) Marshal() ([]byte, error) { - return syntax.Marshal(sh) -} - -func (sh *ServerHelloBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, sh) -} - -// struct { -// opaque verify_data[verify_data_length]; -// } Finished; -// -// verifyDataLen is not a field in the TLS struct, but we add it here so -// that calling code can tell us how much data to expect when we marshal / -// unmarshal. (We could add this to the marshal/unmarshal methods, but let's -// try to keep the signature consistent for now.) -// -// For similar reasons, we don't use the `syntax` module here, because this -// struct doesn't map well to standard TLS presentation language concepts. -// -// TODO: File a spec bug -type FinishedBody struct { - VerifyDataLen int - VerifyData []byte -} - -func (fin FinishedBody) Type() HandshakeType { - return HandshakeTypeFinished -} - -func (fin FinishedBody) Marshal() ([]byte, error) { - if len(fin.VerifyData) != fin.VerifyDataLen { - return nil, fmt.Errorf("tls.finished: data length mismatch") - } - - body := make([]byte, len(fin.VerifyData)) - copy(body, fin.VerifyData) - return body, nil -} - -func (fin *FinishedBody) Unmarshal(data []byte) (int, error) { - if len(data) < fin.VerifyDataLen { - return 0, fmt.Errorf("tls.finished: Malformed finished; too short") - } - - fin.VerifyData = make([]byte, fin.VerifyDataLen) - copy(fin.VerifyData, data[:fin.VerifyDataLen]) - return fin.VerifyDataLen, nil -} - -// struct { -// Extension extensions<0..2^16-1>; -// } EncryptedExtensions; -// -// Marshal() and Unmarshal() are handled by ExtensionList -type EncryptedExtensionsBody struct { - Extensions ExtensionList `tls:"head=2"` -} - -func (ee EncryptedExtensionsBody) Type() HandshakeType { - return HandshakeTypeEncryptedExtensions -} - -func (ee EncryptedExtensionsBody) Marshal() ([]byte, error) { - return syntax.Marshal(ee) -} - -func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, ee) -} - -// opaque ASN1Cert<1..2^24-1>; -// -// struct { -// ASN1Cert cert_data; -// Extension extensions<0..2^16-1> -// } CertificateEntry; -// -// struct { -// opaque certificate_request_context<0..2^8-1>; -// CertificateEntry certificate_list<0..2^24-1>; -// } Certificate; -type CertificateEntry struct { - CertData *x509.Certificate - Extensions ExtensionList -} - -type CertificateBody struct { - CertificateRequestContext []byte - CertificateList []CertificateEntry -} - -type certificateEntryInner struct { - CertData []byte `tls:"head=3,min=1"` - Extensions ExtensionList `tls:"head=2"` -} - -type certificateBodyInner struct { - CertificateRequestContext []byte `tls:"head=1"` - CertificateList []certificateEntryInner `tls:"head=3"` -} - -func (c CertificateBody) Type() HandshakeType { - return HandshakeTypeCertificate -} - -func (c CertificateBody) Marshal() ([]byte, error) { - inner := certificateBodyInner{ - CertificateRequestContext: c.CertificateRequestContext, - CertificateList: make([]certificateEntryInner, len(c.CertificateList)), - } - - for i, entry := range c.CertificateList { - inner.CertificateList[i] = certificateEntryInner{ - CertData: entry.CertData.Raw, - Extensions: entry.Extensions, - } - } - - return syntax.Marshal(inner) -} - -func (c *CertificateBody) Unmarshal(data []byte) (int, error) { - inner := certificateBodyInner{} - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return read, err - } - - c.CertificateRequestContext = inner.CertificateRequestContext - c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) - - for i, entry := range inner.CertificateList { - c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) - if err != nil { - return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) - } - - c.CertificateList[i].Extensions = entry.Extensions - } - - return read, nil -} - -// struct { -// SignatureScheme algorithm; -// opaque signature<0..2^16-1>; -// } CertificateVerify; -type CertificateVerifyBody struct { - Algorithm SignatureScheme - Signature []byte `tls:"head=2"` -} - -func (cv CertificateVerifyBody) Type() HandshakeType { - return HandshakeTypeCertificateVerify -} - -func (cv CertificateVerifyBody) Marshal() ([]byte, error) { - return syntax.Marshal(cv) -} - -func (cv *CertificateVerifyBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, cv) -} - -func (cv *CertificateVerifyBody) EncodeSignatureInput(data []byte) []byte { - // TODO: Change context for client auth - // TODO: Put this in a const - const context = "TLS 1.3, server CertificateVerify" - sigInput := bytes.Repeat([]byte{0x20}, 64) - sigInput = append(sigInput, []byte(context)...) - sigInput = append(sigInput, []byte{0}...) - sigInput = append(sigInput, data...) - return sigInput -} - -func (cv *CertificateVerifyBody) Sign(privateKey crypto.Signer, handshakeHash []byte) (err error) { - sigInput := cv.EncodeSignatureInput(handshakeHash) - cv.Signature, err = sign(cv.Algorithm, privateKey, sigInput) - logf(logTypeHandshake, "Signed: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) - return -} - -func (cv *CertificateVerifyBody) Verify(publicKey crypto.PublicKey, handshakeHash []byte) error { - sigInput := cv.EncodeSignatureInput(handshakeHash) - logf(logTypeHandshake, "About to verify: alg=[%04x] sigInput=[%x], sig=[%x]", cv.Algorithm, sigInput, cv.Signature) - return verify(cv.Algorithm, publicKey, sigInput, cv.Signature) -} - -// struct { -// opaque certificate_request_context<0..2^8-1>; -// Extension extensions<2..2^16-1>; -// } CertificateRequest; -type CertificateRequestBody struct { - CertificateRequestContext []byte `tls:"head=1"` - Extensions ExtensionList `tls:"head=2"` -} - -func (cr CertificateRequestBody) Type() HandshakeType { - return HandshakeTypeCertificateRequest -} - -func (cr CertificateRequestBody) Marshal() ([]byte, error) { - return syntax.Marshal(cr) -} - -func (cr *CertificateRequestBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, cr) -} - -// struct { -// uint32 ticket_lifetime; -// uint32 ticket_age_add; -// opaque ticket_nonce<1..255>; -// opaque ticket<1..2^16-1>; -// Extension extensions<0..2^16-2>; -// } NewSessionTicket; -type NewSessionTicketBody struct { - TicketLifetime uint32 - TicketAgeAdd uint32 - TicketNonce []byte `tls:"head=1,min=1"` - Ticket []byte `tls:"head=2,min=1"` - Extensions ExtensionList `tls:"head=2"` -} - -const ticketNonceLen = 16 - -func NewSessionTicket(ticketLen int, ticketLifetime uint32) (*NewSessionTicketBody, error) { - buf := make([]byte, 4+ticketNonceLen+ticketLen) - _, err := prng.Read(buf) - if err != nil { - return nil, err - } - - tkt := &NewSessionTicketBody{ - TicketLifetime: ticketLifetime, - TicketAgeAdd: binary.BigEndian.Uint32(buf[:4]), - TicketNonce: buf[4 : 4+ticketNonceLen], - Ticket: buf[4+ticketNonceLen:], - } - - return tkt, err -} - -func (tkt NewSessionTicketBody) Type() HandshakeType { - return HandshakeTypeNewSessionTicket -} - -func (tkt NewSessionTicketBody) Marshal() ([]byte, error) { - return syntax.Marshal(tkt) -} - -func (tkt *NewSessionTicketBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, tkt) -} - -// enum { -// update_not_requested(0), update_requested(1), (255) -// } KeyUpdateRequest; -// -// struct { -// KeyUpdateRequest request_update; -// } KeyUpdate; -type KeyUpdateBody struct { - KeyUpdateRequest KeyUpdateRequest -} - -func (ku KeyUpdateBody) Type() HandshakeType { - return HandshakeTypeKeyUpdate -} - -func (ku KeyUpdateBody) Marshal() ([]byte, error) { - return syntax.Marshal(ku) -} - -func (ku *KeyUpdateBody) Unmarshal(data []byte) (int, error) { - return syntax.Unmarshal(data, ku) -} - -// struct {} EndOfEarlyData; -type EndOfEarlyDataBody struct{} - -func (eoed EndOfEarlyDataBody) Type() HandshakeType { - return HandshakeTypeEndOfEarlyData -} - -func (eoed EndOfEarlyDataBody) Marshal() ([]byte, error) { - return []byte{}, nil -} - -func (eoed *EndOfEarlyDataBody) Unmarshal(data []byte) (int, error) { - return 0, nil -} diff --git a/vendor/github.com/bifurcation/mint/log.go b/vendor/github.com/bifurcation/mint/log.go deleted file mode 100644 index 2fba90de..00000000 --- a/vendor/github.com/bifurcation/mint/log.go +++ /dev/null @@ -1,55 +0,0 @@ -package mint - -import ( - "fmt" - "log" - "os" - "strings" -) - -// We use this environment variable to control logging. It should be a -// comma-separated list of log tags (see below) or "*" to enable all logging. -const logConfigVar = "MINT_LOG" - -// Pre-defined log types -const ( - logTypeCrypto = "crypto" - logTypeHandshake = "handshake" - logTypeNegotiation = "negotiation" - logTypeIO = "io" - logTypeFrameReader = "frame" - logTypeVerbose = "verbose" -) - -var ( - logFunction = log.Printf - logAll = false - logSettings = map[string]bool{} -) - -func init() { - parseLogEnv(os.Environ()) -} - -func parseLogEnv(env []string) { - for _, stmt := range env { - if strings.HasPrefix(stmt, logConfigVar+"=") { - val := stmt[len(logConfigVar)+1:] - - if val == "*" { - logAll = true - } else { - for _, t := range strings.Split(val, ",") { - logSettings[t] = true - } - } - } - } -} - -func logf(tag string, format string, args ...interface{}) { - if logAll || logSettings[tag] { - fullFormat := fmt.Sprintf("[%s] %s", tag, format) - logFunction(fullFormat, args...) - } -} diff --git a/vendor/github.com/bifurcation/mint/mint.svg b/vendor/github.com/bifurcation/mint/mint.svg deleted file mode 100644 index ae32703d..00000000 --- a/vendor/github.com/bifurcation/mint/mint.svg +++ /dev/null @@ -1,101 +0,0 @@ - - - - - - - - image/svg+xml - - - - - - - - - - - - - - - - diff --git a/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/bifurcation/mint/negotiation.go deleted file mode 100644 index 2c80b8d7..00000000 --- a/vendor/github.com/bifurcation/mint/negotiation.go +++ /dev/null @@ -1,218 +0,0 @@ -package mint - -import ( - "bytes" - "encoding/hex" - "fmt" - "time" -) - -func VersionNegotiation(offered, supported []uint16) (bool, uint16) { - for _, offeredVersion := range offered { - for _, supportedVersion := range supported { - logf(logTypeHandshake, "[server] version offered by client [%04x] <> [%04x]", offeredVersion, supportedVersion) - if offeredVersion == supportedVersion { - // XXX: Should probably be highest supported version, but for now, we - // only support one version, so it doesn't really matter. - return true, offeredVersion - } - } - } - - return false, 0 -} - -func DHNegotiation(keyShares []KeyShareEntry, groups []NamedGroup) (bool, NamedGroup, []byte, []byte) { - for _, share := range keyShares { - for _, group := range groups { - if group != share.Group { - continue - } - - pub, priv, err := newKeyShare(share.Group) - if err != nil { - // If we encounter an error, just keep looking - continue - } - - dhSecret, err := keyAgreement(share.Group, share.KeyExchange, priv) - if err != nil { - // If we encounter an error, just keep looking - continue - } - - return true, group, pub, dhSecret - } - } - - return false, 0, nil, nil -} - -const ( - ticketAgeTolerance uint32 = 5 * 1000 // five seconds in milliseconds -) - -func PSKNegotiation(identities []PSKIdentity, binders []PSKBinderEntry, context []byte, psks PreSharedKeyCache) (bool, int, *PreSharedKey, CipherSuiteParams, error) { - logf(logTypeNegotiation, "Negotiating PSK offered=[%d] supported=[%d]", len(identities), psks.Size()) - for i, id := range identities { - identityHex := hex.EncodeToString(id.Identity) - - psk, ok := psks.Get(identityHex) - if !ok { - logf(logTypeNegotiation, "No PSK for identity %x", identityHex) - continue - } - - // For resumption, make sure the ticket age is correct - if psk.IsResumption { - extTicketAge := id.ObfuscatedTicketAge - psk.TicketAgeAdd - knownTicketAge := uint32(time.Since(psk.ReceivedAt) / time.Millisecond) - ticketAgeDelta := knownTicketAge - extTicketAge - if knownTicketAge < extTicketAge { - ticketAgeDelta = extTicketAge - knownTicketAge - } - if ticketAgeDelta > ticketAgeTolerance { - logf(logTypeNegotiation, "WARNING potential replay [%x]", psk.Identity) - logf(logTypeNegotiation, "Ticket age exceeds tolerance |%d - %d| = [%d] > [%d]", - extTicketAge, knownTicketAge, ticketAgeDelta, ticketAgeTolerance) - return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("WARNING Potential replay for identity %x", psk.Identity) - } - } - - params, ok := cipherSuiteMap[psk.CipherSuite] - if !ok { - err := fmt.Errorf("tls.cryptoinit: Unsupported ciphersuite from PSK [%04x]", psk.CipherSuite) - return false, 0, nil, CipherSuiteParams{}, err - } - - // Compute binder - binderLabel := labelExternalBinder - if psk.IsResumption { - binderLabel = labelResumptionBinder - } - - h0 := params.Hash.New().Sum(nil) - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - earlySecret := HkdfExtract(params.Hash, zero, psk.Key) - binderKey := deriveSecret(params, earlySecret, binderLabel, h0) - - // context = ClientHello[truncated] - // context = ClientHello1 + HelloRetryRequest + ClientHello2[truncated] - ctxHash := params.Hash.New() - ctxHash.Write(context) - - binder := computeFinishedData(params, binderKey, ctxHash.Sum(nil)) - if !bytes.Equal(binder, binders[i].Binder) { - logf(logTypeNegotiation, "Binder check failed for identity %x; [%x] != [%x]", psk.Identity, binder, binders[i].Binder) - return false, 0, nil, CipherSuiteParams{}, fmt.Errorf("Binder check failed identity %x", psk.Identity) - } - - logf(logTypeNegotiation, "Using PSK with identity %x", psk.Identity) - return true, i, &psk, params, nil - } - - logf(logTypeNegotiation, "Failed to find a usable PSK") - return false, 0, nil, CipherSuiteParams{}, nil -} - -func PSKModeNegotiation(canDoDH, canDoPSK bool, modes []PSKKeyExchangeMode) (bool, bool) { - logf(logTypeNegotiation, "Negotiating PSK modes [%v] [%v] [%+v]", canDoDH, canDoPSK, modes) - dhAllowed := false - dhRequired := true - for _, mode := range modes { - dhAllowed = dhAllowed || (mode == PSKModeDHEKE) - dhRequired = dhRequired && (mode == PSKModeDHEKE) - } - - // Use PSK if we can meet DH requirement and modes were provided - usingPSK := canDoPSK && (!dhRequired || canDoDH) && (len(modes) > 0) - - // Use DH if allowed - usingDH := canDoDH && (dhAllowed || !usingPSK) - - logf(logTypeNegotiation, "Results of PSK mode negotiation: usingDH=[%v] usingPSK=[%v]", usingDH, usingPSK) - return usingDH, usingPSK -} - -func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme, certs []*Certificate) (*Certificate, SignatureScheme, error) { - // Select for server name if provided - candidates := certs - if serverName != nil { - candidatesByName := []*Certificate{} - for _, cert := range certs { - for _, name := range cert.Chain[0].DNSNames { - if len(*serverName) > 0 && name == *serverName { - candidatesByName = append(candidatesByName, cert) - } - } - } - - if len(candidatesByName) == 0 { - return nil, 0, fmt.Errorf("No certificates available for server name: %s", *serverName) - } - - candidates = candidatesByName - } - - // Select for signature scheme - for _, cert := range candidates { - for _, scheme := range signatureSchemes { - if !schemeValidForKey(scheme, cert.PrivateKey) { - continue - } - - return cert, scheme, nil - } - } - - return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") -} - -func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) { - using = gotEarlyData && usingPSK && allowEarlyData - rejected = gotEarlyData && !using - logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected) - return -} - -func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { - for _, s1 := range offered { - if psk != nil { - if s1 == psk.CipherSuite { - return s1, nil - } - continue - } - - for _, s2 := range supported { - if s1 == s2 { - return s1, nil - } - } - } - - return 0, fmt.Errorf("No overlap between offered and supproted ciphersuites (psk? [%v])", psk != nil) -} - -func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, error) { - for _, p1 := range offered { - if psk != nil { - if p1 != psk.NextProto { - continue - } - } - - for _, p2 := range supported { - if p1 == p2 { - return p1, nil - } - } - } - - // If the client offers ALPN on resumption, it must match the earlier one - var err error - if psk != nil && psk.IsResumption && (len(offered) > 0) { - err = fmt.Errorf("ALPN for PSK not provided") - } - return "", err -} diff --git a/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/bifurcation/mint/record-layer.go deleted file mode 100644 index 5cf8ae2c..00000000 --- a/vendor/github.com/bifurcation/mint/record-layer.go +++ /dev/null @@ -1,458 +0,0 @@ -package mint - -import ( - "crypto/cipher" - "fmt" - "io" - "sync" -) - -const ( - sequenceNumberLen = 8 // sequence number length - recordHeaderLenTLS = 5 // record header length (TLS) - recordHeaderLenDTLS = 13 // record header length (DTLS) - maxFragmentLen = 1 << 14 // max number of bytes in a record -) - -type DecryptError string - -func (err DecryptError) Error() string { - return string(err) -} - -type direction uint8 - -const ( - directionWrite = direction(1) - directionRead = direction(2) -) - -// struct { -// ContentType type; -// ProtocolVersion record_version [0301 for CH, 0303 for others] -// uint16 length; -// opaque fragment[TLSPlaintext.length]; -// } TLSPlaintext; -type TLSPlaintext struct { - // Omitted: record_version (static) - // Omitted: length (computed from fragment) - contentType RecordType - epoch Epoch - seq uint64 - fragment []byte -} - -type cipherState struct { - epoch Epoch // DTLS epoch - ivLength int // Length of the seq and nonce fields - seq uint64 // Zero-padded sequence number - iv []byte // Buffer for the IV - cipher cipher.AEAD // AEAD cipher -} - -type RecordLayer struct { - sync.Mutex - label string - direction direction - version uint16 // The current version number - conn io.ReadWriter // The underlying connection - frame *frameReader // The buffered frame reader - nextData []byte // The next record to send - cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" - cachedError error // Error on the last record read - - cipher *cipherState - readCiphers map[Epoch]*cipherState - - datagram bool -} - -type recordLayerFrameDetails struct { - datagram bool -} - -func (d recordLayerFrameDetails) headerLen() int { - if d.datagram { - return recordHeaderLenDTLS - } - return recordHeaderLenTLS -} - -func (d recordLayerFrameDetails) defaultReadLen() int { - return d.headerLen() + maxFragmentLen -} - -func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { - return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil -} - -func newCipherStateNull() *cipherState { - return &cipherState{EpochClear, 0, 0, nil, nil} -} - -func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) { - cipher, err := factory(key) - if err != nil { - return nil, err - } - - return &cipherState{epoch, len(iv), 0, iv, cipher}, nil -} - -func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer { - r := RecordLayer{} - r.label = "" - r.direction = dir - r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{false}) - r.cipher = newCipherStateNull() - r.version = tls10Version - return &r -} - -func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer { - r := RecordLayer{} - r.label = "" - r.direction = dir - r.conn = conn - r.frame = newFrameReader(recordLayerFrameDetails{true}) - r.cipher = newCipherStateNull() - r.readCiphers = make(map[Epoch]*cipherState, 0) - r.readCiphers[0] = r.cipher - r.datagram = true - return &r -} - -func (r *RecordLayer) SetVersion(v uint16) { - r.version = v -} - -func (r *RecordLayer) ResetClear(seq uint64) { - r.cipher = newCipherStateNull() - r.cipher.seq = seq -} - -func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { - cipher, err := newCipherStateAead(epoch, factory, key, iv) - if err != nil { - return err - } - r.cipher = cipher - if r.datagram && r.direction == directionRead { - r.readCiphers[epoch] = cipher - } - return nil -} - -// TODO(ekr@rtfm.com): This is never used, which is a bug. -func (r *RecordLayer) DiscardReadKey(epoch Epoch) { - if !r.datagram { - return - } - - _, ok := r.readCiphers[epoch] - assert(ok) - delete(r.readCiphers, epoch) -} - -func (c *cipherState) combineSeq(datagram bool) uint64 { - seq := c.seq - if datagram { - seq |= uint64(c.epoch) << 48 - } - return seq -} - -func (c *cipherState) computeNonce(seq uint64) []byte { - nonce := make([]byte, len(c.iv)) - copy(nonce, c.iv) - - s := seq - - offset := len(c.iv) - for i := 0; i < 8; i++ { - nonce[(offset-i)-1] ^= byte(s & 0xff) - s >>= 8 - } - logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce) - - return nonce -} - -func (c *cipherState) incrementSequenceNumber() { - if c.seq >= (1<<48 - 1) { - // Not allowed to let sequence number wrap. - // Instead, must renegotiate before it does. - // Not likely enough to bother. This is the - // DTLS limit. - panic("TLS: sequence number wraparound") - } - c.seq++ -} - -func (c *cipherState) overhead() int { - if c.cipher == nil { - return 0 - } - return c.cipher.Overhead() -} - -func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext { - assert(r.direction == directionWrite) - logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq) - // Expand the fragment to hold contentType, padding, and overhead - originalLen := len(pt.fragment) - plaintextLen := originalLen + 1 + padLen - ciphertextLen := plaintextLen + cipher.overhead() - - // Assemble the revised plaintext - out := &TLSPlaintext{ - - contentType: RecordTypeApplicationData, - fragment: make([]byte, ciphertextLen), - } - copy(out.fragment, pt.fragment) - out.fragment[originalLen] = byte(pt.contentType) - for i := 1; i <= padLen; i++ { - out.fragment[originalLen+i] = 0 - } - - // Encrypt the fragment - payload := out.fragment[:plaintextLen] - cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil) - return out -} - -func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) { - assert(r.direction == directionRead) - logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq) - if len(pt.fragment) < r.cipher.overhead() { - msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead()) - return nil, 0, DecryptError(msg) - } - - decryptLen := len(pt.fragment) - r.cipher.overhead() - out := &TLSPlaintext{ - contentType: pt.contentType, - fragment: make([]byte, decryptLen), - } - - // Decrypt - _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil) - if err != nil { - logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt) - return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") - } - - // Find the padding boundary - padLen := 0 - for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ { - } - - // Transfer the content type - newLen := decryptLen - padLen - 1 - out.contentType = RecordType(out.fragment[newLen]) - - // Truncate the message to remove contentType, padding, overhead - out.fragment = out.fragment[:newLen] - out.seq = seq - return out, padLen, nil -} - -func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { - var pt *TLSPlaintext - var err error - - for { - pt, err = r.nextRecord(false) - if err == nil { - break - } - if !block || err != AlertWouldBlock { - return 0, err - } - } - return pt.contentType, nil -} - -func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { - pt, err := r.nextRecord(false) - - // Consume the cached record if there was one - r.cachedRecord = nil - r.cachedError = nil - - return pt, err -} - -func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) { - pt, err := r.nextRecord(true) - - // Consume the cached record if there was one - r.cachedRecord = nil - r.cachedError = nil - - return pt, err -} - -func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) { - cipher := r.cipher - if r.cachedRecord != nil { - logf(logTypeIO, "%s Returning cached record", r.label) - return r.cachedRecord, r.cachedError - } - - // Loop until one of three things happens: - // - // 1. We get a frame - // 2. We try to read off the socket and get nothing, in which case - // returnAlertWouldBlock - // 3. We get an error. - var err error - err = AlertWouldBlock - var header, body []byte - - for err != nil { - if r.frame.needed() > 0 { - buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) - n, err := r.conn.Read(buf) - if err != nil { - logf(logTypeIO, "%s Error reading, %v", r.label, err) - return nil, err - } - - if n == 0 { - return nil, AlertWouldBlock - } - - logf(logTypeIO, "%s Read %v bytes", r.label, n) - - buf = buf[:n] - r.frame.addChunk(buf) - } - - header, body, err = r.frame.process() - // Loop around onAlertWouldBlock to see if some - // data is now available. - if err != nil && err != AlertWouldBlock { - return nil, err - } - } - - pt := &TLSPlaintext{} - // Validate content type - switch RecordType(header[0]) { - default: - return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) - case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck: - pt.contentType = RecordType(header[0]) - } - - // Validate version - if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) { - return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2]) - } - - // Validate size < max - size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1]) - - if size > maxFragmentLen+256 { - return nil, fmt.Errorf("tls.record: Ciphertext size too big") - } - - pt.fragment = make([]byte, size) - copy(pt.fragment, body) - - // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data. - - // Attempt to decrypt fragment - seq := cipher.seq - if r.datagram { - // TODO(ekr@rtfm.com): Handle duplicates. - seq, _ = decodeUint(header[3:11], 8) - epoch := Epoch(seq >> 48) - - // Look up the cipher suite from the epoch - c, ok := r.readCiphers[epoch] - if !ok { - logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch) - return nil, AlertWouldBlock - } - - if epoch != cipher.epoch { - logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch, - cipher.epoch, allowOldEpoch) - if !allowOldEpoch { - return nil, AlertWouldBlock - } - cipher = c - } - } - - if cipher.cipher != nil { - logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment) - pt, _, err = r.decrypt(pt, seq) - if err != nil { - logf(logTypeIO, "%s Decryption failed", r.label) - return nil, err - } - } - pt.epoch = cipher.epoch - - // Check that plaintext length is not too long - if len(pt.fragment) > maxFragmentLen { - return nil, fmt.Errorf("tls.record: Plaintext size too big") - } - - logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment) - - r.cachedRecord = pt - cipher.incrementSequenceNumber() - return pt, nil -} - -func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { - return r.writeRecordWithPadding(pt, r.cipher, 0) -} - -func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { - return r.writeRecordWithPadding(pt, r.cipher, padLen) -} - -func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error { - seq := cipher.combineSeq(r.datagram) - if cipher.cipher != nil { - logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) - pt = r.encrypt(cipher, seq, pt, padLen) - } else if padLen > 0 { - return fmt.Errorf("tls.record: Padding can only be done on encrypted records") - } - - if len(pt.fragment) > maxFragmentLen { - return fmt.Errorf("tls.record: Record size too big") - } - - length := len(pt.fragment) - var header []byte - - if !r.datagram { - header = []byte{byte(pt.contentType), - byte(r.version >> 8), byte(r.version & 0xff), - byte(length >> 8), byte(length)} - } else { - header = make([]byte, 13) - version := dtlsConvertVersion(r.version) - copy(header, []byte{byte(pt.contentType), - byte(version >> 8), byte(version & 0xff), - }) - encodeUint(seq, 8, header[3:]) - encodeUint(uint64(length), 2, header[11:]) - } - record := append(header, pt.fragment...) - - logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) - - cipher.incrementSequenceNumber() - _, err := r.conn.Write(record) - return err -} diff --git a/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/bifurcation/mint/server-state-machine.go deleted file mode 100644 index f91b22e4..00000000 --- a/vendor/github.com/bifurcation/mint/server-state-machine.go +++ /dev/null @@ -1,1177 +0,0 @@ -package mint - -import ( - "bytes" - "crypto/x509" - "fmt" - "hash" - "reflect" - - "github.com/bifurcation/mint/syntax" -) - -// Server State Machine -// -// START <-----+ -// Recv ClientHello | | Send HelloRetryRequest -// v | -// RECVD_CH ----+ -// | Select parameters -// | Send ServerHello -// v -// NEGOTIATED -// | Send EncryptedExtensions -// | [Send CertificateRequest] -// Can send | [Send Certificate + CertificateVerify] -// app data --> | Send Finished -// after here | -// +-----------+--------+ -// | | | -// Rejected 0-RTT | No | | 0-RTT -// | 0-RTT | | -// | | v -// +---->READ_PAST | WAIT_EOED <---+ -// Decrypt | | | Decrypt | Recv | | | Recv -// error | | | OK + HS | EOED | | | early data -// +-----+ | V | +-----+ -// +---> WAIT_FLIGHT2 <-+ -// | -// +--------+--------+ -// No auth | | Client auth -// | | -// | v -// | WAIT_CERT -// | Recv | | Recv Certificate -// | empty | v -// | Certificate | WAIT_CV -// | | | Recv -// | v | CertificateVerify -// +-> WAIT_FINISHED <---+ -// | Recv Finished -// v -// CONNECTED -// -// NB: Not using state RECVD_CH -// -// State Instructions -// START {} -// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] -// WAIT_EOED RekeyIn; -// READ_PAST {} -// WAIT_FLIGHT2 {} -// WAIT_CERT_CR {} -// WAIT_CERT {} -// WAIT_CV {} -// WAIT_FINISHED RekeyIn; RekeyOut; -// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) - -// A cookie can be sent to the client in a HRR. -type cookie struct { - // The CipherSuite that was selected when the client sent the first ClientHello - CipherSuite CipherSuite - ClientHelloHash []byte `tls:"head=2"` - - // The ApplicationCookie can be provided by the application (by setting a Config.CookieHandler) - ApplicationCookie []byte `tls:"head=2"` -} - -type serverStateStart struct { - Config *Config - conn *Conn - hsCtx *HandshakeContext -} - -var _ HandshakeState = &serverStateStart{} - -func (state serverStateStart) State() State { - return StateServerStart -} - -func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeClientHello { - logf(logTypeHandshake, "[ServerStateStart] unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - ch := &ClientHelloBody{LegacyVersion: wireVersion(state.hsCtx.hIn)} - if err := safeUnmarshal(ch, hm.body); err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - // We are strict about these things because we only support 1.3 - if ch.LegacyVersion != wireVersion(state.hsCtx.hIn) { - logf(logTypeHandshake, "[ServerStateStart] Invalid version number: %v", ch.LegacyVersion) - return nil, nil, AlertDecodeError - } - - clientHello := hm - connParams := ConnectionParameters{} - - supportedVersions := &SupportedVersionsExtension{HandshakeType: HandshakeTypeClientHello} - serverName := new(ServerNameExtension) - supportedGroups := new(SupportedGroupsExtension) - signatureAlgorithms := new(SignatureAlgorithmsExtension) - clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello} - clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello} - clientEarlyData := &EarlyDataExtension{} - clientALPN := new(ALPNExtension) - clientPSKModes := new(PSKKeyExchangeModesExtension) - clientCookie := new(CookieExtension) - - // Handle external extensions. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Receive(HandshakeTypeClientHello, &ch.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error running external extension handler [%v]", err) - return nil, nil, AlertInternalError - } - } - - foundExts, err := ch.Extensions.Parse( - []ExtensionBody{ - supportedVersions, - serverName, - supportedGroups, - signatureAlgorithms, - clientEarlyData, - clientKeyShares, - clientPSK, - clientALPN, - clientPSKModes, - clientCookie, - }) - - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error parsing extensions [%v]", err) - return nil, nil, AlertDecodeError - } - - clientSentCookie := len(clientCookie.Cookie) > 0 - - if foundExts[ExtensionTypeServerName] { - connParams.ServerName = string(*serverName) - } - - // If the client didn't send supportedVersions or doesn't support 1.3, - // then we're done here. - if !foundExts[ExtensionTypeSupportedVersions] { - logf(logTypeHandshake, "[ServerStateStart] Client did not send supported_versions") - return nil, nil, AlertProtocolVersion - } - versionOK, _ := VersionNegotiation(supportedVersions.Versions, []uint16{supportedVersion}) - if !versionOK { - logf(logTypeHandshake, "[ServerStateStart] Client does not support the same version") - return nil, nil, AlertProtocolVersion - } - - // The client sent a cookie. So this is probably the second ClientHello (sent as a response to a HRR) - var firstClientHello *HandshakeMessage - var initialCipherSuite CipherSuiteParams // the cipher suite that was negotiated when sending the HelloRetryRequest - if clientSentCookie { - plainCookie, err := state.Config.CookieProtector.DecodeToken(clientCookie.Cookie) - if err != nil { - logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error decoding token [%v]", err)) - return nil, nil, AlertDecryptError - } - cookie := &cookie{} - if rb, err := syntax.Unmarshal(plainCookie, cookie); err != nil && rb != len(plainCookie) { // this should never happen - logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Error unmarshaling cookie [%v]", err)) - return nil, nil, AlertInternalError - } - // restore the hash of initial ClientHello from the cookie - firstClientHello = &HandshakeMessage{ - msgType: HandshakeTypeMessageHash, - body: cookie.ClientHelloHash, - } - // have the application validate its part of the cookie - if state.Config.CookieHandler != nil && !state.Config.CookieHandler.Validate(state.conn, cookie.ApplicationCookie) { - logf(logTypeHandshake, "[ServerStateStart] Cookie mismatch") - return nil, nil, AlertAccessDenied - } - var ok bool - initialCipherSuite, ok = cipherSuiteMap[cookie.CipherSuite] - if !ok { - logf(logTypeHandshake, fmt.Sprintf("[ServerStateStart] Cookie contained invalid cipher suite: %#x", cookie.CipherSuite)) - return nil, nil, AlertInternalError - } - } - - if len(ch.LegacySessionID) != 0 && len(ch.LegacySessionID) != 32 { - logf(logTypeHandshake, "[ServerStateStart] invalid session ID") - return nil, nil, AlertIllegalParameter - } - - // Figure out if we can do DH - canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Config.Groups) - - // Figure out if we can do PSK - var canDoPSK bool - var selectedPSK int - var params CipherSuiteParams - var psk *PreSharedKey - if len(clientPSK.Identities) > 0 { - contextBase := []byte{} - if clientSentCookie { - contextBase = append(contextBase, firstClientHello.Marshal()...) - // fill in the cookie sent by the client. Needed to calculate the correct hash - cookieExt := &CookieExtension{Cookie: clientCookie.Cookie} - hrr, err := state.generateHRR(params.Suite, - ch.LegacySessionID, cookieExt) - if err != nil { - return nil, nil, AlertInternalError - } - contextBase = append(contextBase, hrr.Marshal()...) - } - chTrunc, err := ch.Truncated() - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error computing truncated ClientHello [%v]", err) - return nil, nil, AlertDecodeError - } - context := append(contextBase, chTrunc...) - - canDoPSK, selectedPSK, psk, params, err = PSKNegotiation(clientPSK.Identities, clientPSK.Binders, context, state.Config.PSKs) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Figure out if we actually should do DH / PSK - connParams.UsingDH, connParams.UsingPSK = PSKModeNegotiation(canDoDH, canDoPSK, clientPSKModes.KEModes) - - // Select a ciphersuite - connParams.CipherSuite, err = CipherSuiteNegotiation(psk, ch.CipherSuites, state.Config.CipherSuites) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] No common ciphersuite found [%v]", err) - return nil, nil, AlertHandshakeFailure - } - if clientSentCookie && initialCipherSuite.Suite != connParams.CipherSuite { - logf(logTypeHandshake, "[ServerStateStart] Would have selected a different CipherSuite after receiving the client's Cookie") - return nil, nil, AlertInternalError - } - - var helloRetryRequest *HandshakeMessage - if state.Config.RequireCookie { - // Send a cookie if required - // NB: Need to do this here because it's after ciphersuite selection, which - // has to be after PSK selection. - var shouldSendHRR bool - var cookieExt *CookieExtension - if !clientSentCookie { // this is the first ClientHello that we receive - var appCookie []byte - if state.Config.CookieHandler == nil { // if Config.RequireCookie is set, but no CookieHandler was provided, we definitely need to send a cookie - shouldSendHRR = true - } else { // if the CookieHandler was set, we just send a cookie when the application provides one - var err error - appCookie, err = state.Config.CookieHandler.Generate(state.conn) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error generating cookie [%v]", err) - return nil, nil, AlertInternalError - } - shouldSendHRR = appCookie != nil - } - if shouldSendHRR { - params := cipherSuiteMap[connParams.CipherSuite] - h := params.Hash.New() - h.Write(clientHello.Marshal()) - plainCookie, err := syntax.Marshal(cookie{ - CipherSuite: connParams.CipherSuite, - ClientHelloHash: h.Sum(nil), - ApplicationCookie: appCookie, - }) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error marshalling cookie [%v]", err) - return nil, nil, AlertInternalError - } - cookieData, err := state.Config.CookieProtector.NewToken(plainCookie) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error encoding cookie [%v]", err) - return nil, nil, AlertInternalError - } - cookieExt = &CookieExtension{Cookie: cookieData} - } - } else { - cookieExt = &CookieExtension{Cookie: clientCookie.Cookie} - } - - // Generate a HRR. We will need it in both of the two cases: - // 1. We need to send a Cookie. Then this HRR will be sent on the wire - // 2. We need to validate a cookie. Then we need its hash - // Ignoring errors because everything here is newly constructed, so there - // shouldn't be marshal errors - if shouldSendHRR || clientSentCookie { - helloRetryRequest, err = state.generateHRR(connParams.CipherSuite, - ch.LegacySessionID, cookieExt) - if err != nil { - return nil, nil, AlertInternalError - } - } - - if shouldSendHRR { - toSend := []HandshakeAction{ - QueueHandshakeMessage{helloRetryRequest}, - SendQueuedHandshake{}, - } - logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateStart]") - return state, toSend, AlertStatelessRetry - } - } - - // If we've got no entropy to make keys from, fail - if !connParams.UsingDH && !connParams.UsingPSK { - logf(logTypeHandshake, "[ServerStateStart] Neither DH nor PSK negotiated") - return nil, nil, AlertHandshakeFailure - } - - var pskSecret []byte - var cert *Certificate - var certScheme SignatureScheme - if connParams.UsingPSK { - pskSecret = psk.Key - } else { - psk = nil - - // If we're not using a PSK mode, then we need to have certain extensions - if !(foundExts[ExtensionTypeServerName] && - foundExts[ExtensionTypeSupportedGroups] && - foundExts[ExtensionTypeSignatureAlgorithms]) { - logf(logTypeHandshake, "[ServerStateStart] Insufficient extensions (%v)", foundExts) - return nil, nil, AlertMissingExtension - } - - // Select a certificate - name := string(*serverName) - var err error - cert, certScheme, err = CertificateSelection(&name, signatureAlgorithms.Algorithms, state.Config.Certificates) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] No appropriate certificate found [%v]", err) - return nil, nil, AlertAccessDenied - } - } - - if !connParams.UsingDH { - dhSecret = nil - } - - // Figure out if we're going to do early data - var clientEarlyTrafficSecret []byte - connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData] - connParams.UsingEarlyData, connParams.RejectedEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) - if connParams.UsingEarlyData { - h := params.Hash.New() - h.Write(clientHello.Marshal()) - chHash := h.Sum(nil) - - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - earlySecret := HkdfExtract(params.Hash, zero, pskSecret) - clientEarlyTrafficSecret = deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) - } - - // Select a next protocol - connParams.NextProto, err = ALPNNegotiation(psk, clientALPN.Protocols, state.Config.NextProtos) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] No common application-layer protocol found [%v]", err) - return nil, nil, AlertNoApplicationProtocol - } - - state.hsCtx.receivedEndOfFlight() - - logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") - state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. - return serverStateNegotiated{ - Config: state.Config, - Params: connParams, - hsCtx: state.hsCtx, - dhGroup: dhGroup, - dhPublic: dhPublic, - dhSecret: dhSecret, - pskSecret: pskSecret, - selectedPSK: selectedPSK, - cert: cert, - certScheme: certScheme, - legacySessionId: ch.LegacySessionID, - clientEarlyTrafficSecret: clientEarlyTrafficSecret, - - firstClientHello: firstClientHello, - helloRetryRequest: helloRetryRequest, - clientHello: clientHello, - }, nil, AlertNoAlert -} - -func (state *serverStateStart) generateHRR(cs CipherSuite, legacySessionId []byte, - cookieExt *CookieExtension) (*HandshakeMessage, error) { - var helloRetryRequest *HandshakeMessage - hrr := &ServerHelloBody{ - Version: tls12Version, - Random: hrrRandomSentinel, - CipherSuite: cs, - LegacySessionID: legacySessionId, - LegacyCompressionMethod: 0, - } - - sv := &SupportedVersionsExtension{ - HandshakeType: HandshakeTypeServerHello, - Versions: []uint16{supportedVersion}, - } - - if err := hrr.Extensions.Add(sv); err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error adding SupportedVersion [%v]", err) - return nil, err - } - - if err := hrr.Extensions.Add(cookieExt); err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error adding CookieExtension [%v]", err) - return nil, err - } - // Run the external extension handler. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Send(HandshakeTypeHelloRetryRequest, &hrr.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error running external extension sender [%v]", err) - return nil, err - } - } - helloRetryRequest, err := state.hsCtx.hOut.HandshakeMessageFromBody(hrr) - if err != nil { - logf(logTypeHandshake, "[ServerStateStart] Error marshaling HRR [%v]", err) - return nil, err - } - return helloRetryRequest, nil -} - -type serverStateNegotiated struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - dhGroup NamedGroup - dhPublic []byte - dhSecret []byte - pskSecret []byte - clientEarlyTrafficSecret []byte - selectedPSK int - cert *Certificate - certScheme SignatureScheme - legacySessionId []byte - firstClientHello *HandshakeMessage - helloRetryRequest *HandshakeMessage - clientHello *HandshakeMessage -} - -var _ HandshakeState = &serverStateNegotiated{} - -func (state serverStateNegotiated) State() State { - return StateServerNegotiated -} - -func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - // Create the ServerHello - sh := &ServerHelloBody{ - Version: tls12Version, - CipherSuite: state.Params.CipherSuite, - LegacySessionID: state.legacySessionId, - LegacyCompressionMethod: 0, - } - if _, err := prng.Read(sh.Random[:]); err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error creating server random [%v]", err) - return nil, nil, AlertInternalError - } - - err := sh.Extensions.Add(&SupportedVersionsExtension{ - HandshakeType: HandshakeTypeServerHello, - Versions: []uint16{supportedVersion}, - }) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported_versions extension [%v]", err) - return nil, nil, AlertInternalError - } - if state.Params.UsingDH { - logf(logTypeHandshake, "[ServerStateNegotiated] sending DH extension") - err := sh.Extensions.Add(&KeyShareExtension{ - HandshakeType: HandshakeTypeServerHello, - Shares: []KeyShareEntry{{Group: state.dhGroup, KeyExchange: state.dhPublic}}, - }) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding key_shares extension [%v]", err) - return nil, nil, AlertInternalError - } - } - if state.Params.UsingPSK { - logf(logTypeHandshake, "[ServerStateNegotiated] sending PSK extension") - err := sh.Extensions.Add(&PreSharedKeyExtension{ - HandshakeType: HandshakeTypeServerHello, - SelectedIdentity: uint16(state.selectedPSK), - }) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding PSK extension [%v]", err) - return nil, nil, AlertInternalError - } - } - - // Run the external extension handler. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) - return nil, nil, AlertInternalError - } - } - - serverHello, err := state.hsCtx.hOut.HandshakeMessageFromBody(sh) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling ServerHello [%v]", err) - return nil, nil, AlertInternalError - } - - // Look up crypto params - params, ok := cipherSuiteMap[sh.CipherSuite] - if !ok { - logf(logTypeCrypto, "Unsupported ciphersuite [%04x]", sh.CipherSuite) - return nil, nil, AlertHandshakeFailure - } - - // Start up the handshake hash - handshakeHash := params.Hash.New() - handshakeHash.Write(state.firstClientHello.Marshal()) - handshakeHash.Write(state.helloRetryRequest.Marshal()) - handshakeHash.Write(state.clientHello.Marshal()) - handshakeHash.Write(serverHello.Marshal()) - - // Compute handshake secrets - zero := bytes.Repeat([]byte{0}, params.Hash.Size()) - - var earlySecret []byte - if state.Params.UsingPSK { - earlySecret = HkdfExtract(params.Hash, zero, state.pskSecret) - } else { - earlySecret = HkdfExtract(params.Hash, zero, zero) - } - - if state.dhSecret == nil { - state.dhSecret = zero - } - - h0 := params.Hash.New().Sum(nil) - h2 := handshakeHash.Sum(nil) - preHandshakeSecret := deriveSecret(params, earlySecret, labelDerived, h0) - handshakeSecret := HkdfExtract(params.Hash, preHandshakeSecret, state.dhSecret) - clientHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelClientHandshakeTrafficSecret, h2) - serverHandshakeTrafficSecret := deriveSecret(params, handshakeSecret, labelServerHandshakeTrafficSecret, h2) - preMasterSecret := deriveSecret(params, handshakeSecret, labelDerived, h0) - masterSecret := HkdfExtract(params.Hash, preMasterSecret, zero) - - logf(logTypeCrypto, "early secret (init!): [%d] %x", len(earlySecret), earlySecret) - logf(logTypeCrypto, "handshake secret: [%d] %x", len(handshakeSecret), handshakeSecret) - logf(logTypeCrypto, "client handshake traffic secret: [%d] %x", len(clientHandshakeTrafficSecret), clientHandshakeTrafficSecret) - logf(logTypeCrypto, "server handshake traffic secret: [%d] %x", len(serverHandshakeTrafficSecret), serverHandshakeTrafficSecret) - logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) - - clientHandshakeKeys := makeTrafficKeys(params, clientHandshakeTrafficSecret) - serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - - // Send an EncryptedExtensions message (even if it's empty) - eeList := ExtensionList{} - if state.Params.NextProto != "" { - logf(logTypeHandshake, "[server] sending ALPN extension") - err = eeList.Add(&ALPNExtension{Protocols: []string{state.Params.NextProto}}) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding ALPN to EncryptedExtensions [%v]", err) - return nil, nil, AlertInternalError - } - } - if state.Params.UsingEarlyData { - logf(logTypeHandshake, "[server] sending EDI extension") - err = eeList.Add(&EarlyDataExtension{}) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding EDI to EncryptedExtensions [%v]", err) - return nil, nil, AlertInternalError - } - } - ee := &EncryptedExtensionsBody{eeList} - - // Run the external extension handler. - if state.Config.ExtensionHandler != nil { - err := state.Config.ExtensionHandler.Send(HandshakeTypeEncryptedExtensions, &ee.Extensions) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error running external extension sender [%v]", err) - return nil, nil, AlertInternalError - } - } - - eem, err := state.hsCtx.hOut.HandshakeMessageFromBody(ee) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling EncryptedExtensions [%v]", err) - return nil, nil, AlertInternalError - } - - handshakeHash.Write(eem.Marshal()) - - toSend := []HandshakeAction{ - QueueHandshakeMessage{serverHello}, - RekeyOut{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, - QueueHandshakeMessage{eem}, - } - - // Authenticate with a certificate if required - if !state.Params.UsingPSK { - // Send a CertificateRequest message if we want client auth - if state.Config.RequireClientAuth { - state.Params.UsingClientAuth = true - - // XXX: We don't support sending any constraints besides a list of - // supported signature algorithms - cr := &CertificateRequestBody{} - schemes := &SignatureAlgorithmsExtension{Algorithms: state.Config.SignatureSchemes} - err := cr.Extensions.Add(schemes) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error adding supported schemes to CertificateRequest [%v]", err) - return nil, nil, AlertInternalError - } - - crm, err := state.hsCtx.hOut.HandshakeMessageFromBody(cr) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateRequest [%v]", err) - return nil, nil, AlertInternalError - } - //TODO state.state.serverCertificateRequest = cr - - toSend = append(toSend, QueueHandshakeMessage{crm}) - handshakeHash.Write(crm.Marshal()) - } - - // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(state.cert.Chain)), - } - for i, entry := range state.cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} - } - certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, QueueHandshakeMessage{certm}) - handshakeHash.Write(certm.Marshal()) - - certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} - logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) - - hcv := handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - err = certificateVerify.Sign(state.cert.PrivateKey, hcv) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error signing CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - certvm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificateVerify) - if err != nil { - logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling CertificateVerify [%v]", err) - return nil, nil, AlertInternalError - } - - toSend = append(toSend, QueueHandshakeMessage{certvm}) - handshakeHash.Write(certvm.Marshal()) - } - - // Compute secrets resulting from the server's first flight - h3 := handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 3 [%d] %x", len(h3), h3) - logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h3), h3) - - serverFinishedData := computeFinishedData(params, serverHandshakeTrafficSecret, h3) - logf(logTypeCrypto, "server finished data: [%d] %x", len(serverFinishedData), serverFinishedData) - - // Assemble the Finished message - fin := &FinishedBody{ - VerifyDataLen: len(serverFinishedData), - VerifyData: serverFinishedData, - } - finm, _ := state.hsCtx.hOut.HandshakeMessageFromBody(fin) - - toSend = append(toSend, QueueHandshakeMessage{finm}) - handshakeHash.Write(finm.Marshal()) - toSend = append(toSend, SendQueuedHandshake{}) - - // Compute traffic secrets - h4 := handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 4 [%d] %x", len(h4), h4) - logf(logTypeCrypto, "handshake hash for server Finished: [%d] %x", len(h4), h4) - - clientTrafficSecret := deriveSecret(params, masterSecret, labelClientApplicationTrafficSecret, h4) - serverTrafficSecret := deriveSecret(params, masterSecret, labelServerApplicationTrafficSecret, h4) - logf(logTypeCrypto, "client traffic secret: [%d] %x", len(clientTrafficSecret), clientTrafficSecret) - logf(logTypeCrypto, "server traffic secret: [%d] %x", len(serverTrafficSecret), serverTrafficSecret) - - serverTrafficKeys := makeTrafficKeys(params, serverTrafficSecret) - toSend = append(toSend, RekeyOut{epoch: EpochApplicationData, KeySet: serverTrafficKeys}) - - exporterSecret := deriveSecret(params, masterSecret, labelExporterSecret, h4) - logf(logTypeCrypto, "server exporter secret: [%d] %x", len(exporterSecret), exporterSecret) - - if state.Params.UsingEarlyData { - clientEarlyTrafficKeys := makeTrafficKeys(params, state.clientEarlyTrafficSecret) - - logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitEOED]") - nextState := serverStateWaitEOED{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: params, - handshakeHash: handshakeHash, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - clientTrafficSecret: clientTrafficSecret, - serverTrafficSecret: serverTrafficSecret, - exporterSecret: exporterSecret, - } - toSend = append(toSend, []HandshakeAction{ - RekeyIn{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - }...) - return nextState, toSend, AlertNoAlert - } - - logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") - toSend = append(toSend, []HandshakeAction{ - RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, - }...) - var nextState HandshakeState - nextState = serverStateWaitFlight2{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: params, - handshakeHash: handshakeHash, - masterSecret: masterSecret, - clientHandshakeTrafficSecret: clientHandshakeTrafficSecret, - clientTrafficSecret: clientTrafficSecret, - serverTrafficSecret: serverTrafficSecret, - exporterSecret: exporterSecret, - } - if state.Params.RejectedEarlyData { - nextState = serverStateReadPastEarlyData{ - hsCtx: state.hsCtx, - next: &nextState, - } - } - return nextState, toSend, AlertNoAlert -} - -type serverStateWaitEOED struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - masterSecret []byte - clientHandshakeTrafficSecret []byte - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -var _ HandshakeState = &serverStateWaitEOED{} - -func (state serverStateWaitEOED) State() State { - return StateServerWaitEOED -} - -func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - for { - logf(logTypeHandshake, "Server reading early data...") - assert(state.hsCtx.hIn.conn.cipher.epoch == EpochEarlyData) - t, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) - if err == AlertWouldBlock { - return nil, nil, AlertWouldBlock - } - - if err != nil { - logf(logTypeHandshake, "Server Error reading record type (1): %v", err) - return nil, nil, AlertBadRecordMAC - } - - logf(logTypeHandshake, "Server got record type(1): %v", t) - - if t != RecordTypeApplicationData { - break - } - - // Read a record into the buffer. Note that this is safe - // in blocking mode because we read the record in - // PeekRecordType. - pt, err := state.hsCtx.hIn.conn.ReadRecord() - if err != nil { - logf(logTypeHandshake, "Server error reading early data record: %v", err) - return nil, nil, AlertInternalError - } - - logf(logTypeHandshake, "Server read early data: %x", pt.fragment) - state.hsCtx.earlyData = append(state.hsCtx.earlyData, pt.fragment...) - } - - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeEndOfEarlyData { - logf(logTypeHandshake, "[ServerStateWaitEOED] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - if len(hm.body) > 0 { - logf(logTypeHandshake, "[ServerStateWaitEOED] Error decoding message [len > 0]") - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) - - logf(logTypeHandshake, "[ServerStateWaitEOED] -> [ServerStateWaitFlight2]") - toSend := []HandshakeAction{ - RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, - } - waitFlight2 := serverStateWaitFlight2{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return waitFlight2, toSend, AlertNoAlert -} - -var _ HandshakeState = &serverStateReadPastEarlyData{} - -type serverStateReadPastEarlyData struct { - hsCtx *HandshakeContext - next *HandshakeState -} - -func (state serverStateReadPastEarlyData) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - for { - logf(logTypeHandshake, "Server reading past early data...") - // Scan past all records that fail to decrypt - _, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) - if err == nil { - break - } - - if err == AlertWouldBlock { - return nil, nil, AlertWouldBlock - } - - // Continue on DecryptError - _, ok := err.(DecryptError) - if !ok { - return nil, nil, AlertInternalError // Really need something else. - } - } - - return *state.next, nil, AlertNoAlert -} - -func (state serverStateReadPastEarlyData) State() State { - return StateServerReadPastEarlyData -} - -type serverStateWaitFlight2 struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - masterSecret []byte - clientHandshakeTrafficSecret []byte - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -var _ HandshakeState = &serverStateWaitFlight2{} - -func (state serverStateWaitFlight2) State() State { - return StateServerWaitFlight2 -} - -func (state serverStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - if state.Params.UsingClientAuth { - logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitCert]") - nextState := serverStateWaitCert{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - handshakeHash: state.handshakeHash, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert - } - - logf(logTypeHandshake, "[ServerStateWaitFlight2] -> [ServerStateWaitFinished]") - nextState := serverStateWaitFinished{ - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert -} - -type serverStateWaitCert struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - masterSecret []byte - clientHandshakeTrafficSecret []byte - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -var _ HandshakeState = &serverStateWaitCert{} - -func (state serverStateWaitCert) State() State { - return StateServerWaitCert -} - -func (state serverStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeCertificate { - logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - cert := &CertificateBody{} - if err := safeUnmarshal(cert, hm.body); err != nil { - logf(logTypeHandshake, "[ServerStateWaitCert] Unexpected message") - return nil, nil, AlertDecodeError - } - - state.handshakeHash.Write(hm.Marshal()) - - if len(cert.CertificateList) == 0 { - logf(logTypeHandshake, "[ServerStateWaitCert] WARNING client did not provide a certificate") - - logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitFinished]") - nextState := serverStateWaitFinished{ - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert - } - - logf(logTypeHandshake, "[ServerStateWaitCert] -> [ServerStateWaitCV]") - nextState := serverStateWaitCV{ - Config: state.Config, - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - clientCertificate: cert, - exporterSecret: state.exporterSecret, - } - return nextState, nil, AlertNoAlert -} - -type serverStateWaitCV struct { - Config *Config - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - - masterSecret []byte - clientHandshakeTrafficSecret []byte - - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte - - clientCertificate *CertificateBody -} - -var _ HandshakeState = &serverStateWaitCV{} - -func (state serverStateWaitCV) State() State { - return StateServerWaitCV -} - -func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeCertificateVerify { - logf(logTypeHandshake, "[ServerStateWaitCV] Unexpected message [%+v] [%s]", hm, reflect.TypeOf(hm)) - return nil, nil, AlertUnexpectedMessage - } - - certVerify := &CertificateVerifyBody{} - if err := safeUnmarshal(certVerify, hm.body); err != nil { - logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) - return nil, nil, AlertDecodeError - } - - rawCerts := make([][]byte, len(state.clientCertificate.CertificateList)) - certs := make([]*x509.Certificate, len(state.clientCertificate.CertificateList)) - for i, certEntry := range state.clientCertificate.CertificateList { - certs[i] = certEntry.CertData - rawCerts[i] = certEntry.CertData.Raw - } - - // Verify client signature over handshake hash - hcv := state.handshakeHash.Sum(nil) - logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - - clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey - if err := certVerify.Verify(clientPublicKey, hcv); err != nil { - logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) - return nil, nil, AlertHandshakeFailure - } - - if state.Config.VerifyPeerCertificate != nil { - // TODO(#171): pass in the verified chains, once we support different client auth types - if err := state.Config.VerifyPeerCertificate(rawCerts, nil); err != nil { - logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate: %s", err) - return nil, nil, AlertBadCertificate - } - } - - // If it passes, record the certificateVerify in the transcript hash - state.handshakeHash.Write(hm.Marshal()) - - logf(logTypeHandshake, "[ServerStateWaitCV] -> [ServerStateWaitFinished]") - nextState := serverStateWaitFinished{ - Params: state.Params, - hsCtx: state.hsCtx, - cryptoParams: state.cryptoParams, - masterSecret: state.masterSecret, - clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, - handshakeHash: state.handshakeHash, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - peerCertificates: certs, - verifiedChains: nil, // TODO(#171): set this value - } - return nextState, nil, AlertNoAlert -} - -type serverStateWaitFinished struct { - Params ConnectionParameters - hsCtx *HandshakeContext - cryptoParams CipherSuiteParams - - masterSecret []byte - clientHandshakeTrafficSecret []byte - peerCertificates []*x509.Certificate - verifiedChains [][]*x509.Certificate - - handshakeHash hash.Hash - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte -} - -var _ HandshakeState = &serverStateWaitFinished{} - -func (state serverStateWaitFinished) State() State { - return StateServerWaitFinished -} - -func (state serverStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - hm, alert := hr.ReadMessage() - if alert != AlertNoAlert { - return nil, nil, alert - } - if hm == nil || hm.msgType != HandshakeTypeFinished { - logf(logTypeHandshake, "[ServerStateWaitFinished] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - fin := &FinishedBody{VerifyDataLen: state.cryptoParams.Hash.Size()} - if err := safeUnmarshal(fin, hm.body); err != nil { - logf(logTypeHandshake, "[ServerStateWaitFinished] Error decoding message %v", err) - return nil, nil, AlertDecodeError - } - - // Verify client Finished data - h5 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash for client Finished: [%d] %x", len(h5), h5) - - clientFinishedData := computeFinishedData(state.cryptoParams, state.clientHandshakeTrafficSecret, h5) - logf(logTypeCrypto, "client Finished data: [%d] %x", len(clientFinishedData), clientFinishedData) - - if !bytes.Equal(fin.VerifyData, clientFinishedData) { - logf(logTypeHandshake, "[ServerStateWaitFinished] Client's Finished failed to verify") - return nil, nil, AlertHandshakeFailure - } - - // Compute the resumption secret - state.handshakeHash.Write(hm.Marshal()) - h6 := state.handshakeHash.Sum(nil) - logf(logTypeCrypto, "handshake hash 6 [%d]: %x", len(h6), h6) - - resumptionSecret := deriveSecret(state.cryptoParams, state.masterSecret, labelResumptionSecret, h6) - logf(logTypeCrypto, "resumption secret: [%d] %x", len(resumptionSecret), resumptionSecret) - - // Compute client traffic keys - clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) - - state.hsCtx.receivedFinalFlight() - - logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") - nextState := stateConnected{ - Params: state.Params, - hsCtx: state.hsCtx, - isClient: false, - cryptoParams: state.cryptoParams, - resumptionSecret: resumptionSecret, - clientTrafficSecret: state.clientTrafficSecret, - serverTrafficSecret: state.serverTrafficSecret, - exporterSecret: state.exporterSecret, - peerCertificates: state.peerCertificates, - verifiedChains: state.verifiedChains, - } - toSend := []HandshakeAction{ - RekeyIn{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, - } - return nextState, toSend, AlertNoAlert -} diff --git a/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/bifurcation/mint/state-machine.go deleted file mode 100644 index 558b76cc..00000000 --- a/vendor/github.com/bifurcation/mint/state-machine.go +++ /dev/null @@ -1,247 +0,0 @@ -package mint - -import ( - "crypto/x509" - "time" -) - -// Marker interface for actions that an implementation should take based on -// state transitions. -type HandshakeAction interface{} - -type QueueHandshakeMessage struct { - Message *HandshakeMessage -} - -type SendQueuedHandshake struct{} - -type SendEarlyData struct{} - -type RekeyIn struct { - epoch Epoch - KeySet keySet -} - -type RekeyOut struct { - epoch Epoch - KeySet keySet -} - -type ResetOut struct { - seq uint64 -} - -type StorePSK struct { - PSK PreSharedKey -} - -type HandshakeState interface { - Next(handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) - State() State -} - -type AppExtensionHandler interface { - Send(hs HandshakeType, el *ExtensionList) error - Receive(hs HandshakeType, el *ExtensionList) error -} - -// ConnectionOptions objects represent per-connection settings for a client -// initiating a connection -type ConnectionOptions struct { - ServerName string - NextProtos []string -} - -// ConnectionParameters objects represent the parameters negotiated for a -// connection. -type ConnectionParameters struct { - UsingPSK bool - UsingDH bool - ClientSendingEarlyData bool - UsingEarlyData bool - RejectedEarlyData bool - UsingClientAuth bool - - CipherSuite CipherSuite - ServerName string - NextProto string -} - -// Working state for the handshake. -type HandshakeContext struct { - timeoutMS uint32 - timers *timerSet - recvdRecords []uint64 - sentFragments []*SentHandshakeFragment - hIn, hOut *HandshakeLayer - waitingNextFlight bool - earlyData []byte -} - -func (hc *HandshakeContext) SetVersion(version uint16) { - if hc.hIn.conn != nil { - hc.hIn.conn.SetVersion(version) - } - if hc.hOut.conn != nil { - hc.hOut.conn.SetVersion(version) - } -} - -// stateConnected is symmetric between client and server -type stateConnected struct { - Params ConnectionParameters - hsCtx *HandshakeContext - isClient bool - cryptoParams CipherSuiteParams - resumptionSecret []byte - clientTrafficSecret []byte - serverTrafficSecret []byte - exporterSecret []byte - peerCertificates []*x509.Certificate - verifiedChains [][]*x509.Certificate -} - -var _ HandshakeState = &stateConnected{} - -func (state stateConnected) State() State { - if state.isClient { - return StateClientConnected - } - return StateServerConnected -} - -func (state *stateConnected) KeyUpdate(request KeyUpdateRequest) ([]HandshakeAction, Alert) { - var trafficKeys keySet - if state.isClient { - state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, - labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) - } else { - state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, - labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) - } - - kum, err := state.hsCtx.hOut.HandshakeMessageFromBody(&KeyUpdateBody{KeyUpdateRequest: request}) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error marshaling key update message: %v", err) - return nil, AlertInternalError - } - - toSend := []HandshakeAction{ - QueueHandshakeMessage{kum}, - SendQueuedHandshake{}, - RekeyOut{epoch: EpochUpdate, KeySet: trafficKeys}, - } - return toSend, AlertNoAlert -} - -func (state *stateConnected) NewSessionTicket(length int, lifetime, earlyDataLifetime uint32) ([]HandshakeAction, Alert) { - tkt, err := NewSessionTicket(length, lifetime) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error generating NewSessionTicket: %v", err) - return nil, AlertInternalError - } - - err = tkt.Extensions.Add(&TicketEarlyDataInfoExtension{earlyDataLifetime}) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error adding extension to NewSessionTicket: %v", err) - return nil, AlertInternalError - } - - resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, - labelResumption, tkt.TicketNonce, state.cryptoParams.Hash.Size()) - - newPSK := PreSharedKey{ - CipherSuite: state.cryptoParams.Suite, - IsResumption: true, - Identity: tkt.Ticket, - Key: resumptionKey, - NextProto: state.Params.NextProto, - ReceivedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Duration(tkt.TicketLifetime) * time.Second), - TicketAgeAdd: tkt.TicketAgeAdd, - } - - tktm, err := state.hsCtx.hOut.HandshakeMessageFromBody(tkt) - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error marshaling NewSessionTicket: %v", err) - return nil, AlertInternalError - } - - toSend := []HandshakeAction{ - StorePSK{newPSK}, - QueueHandshakeMessage{tktm}, - SendQueuedHandshake{}, - } - return toSend, AlertNoAlert -} - -// Next does nothing for this state. -func (state stateConnected) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { - return state, nil, AlertNoAlert -} - -func (state stateConnected) ProcessMessage(hm *HandshakeMessage) (HandshakeState, []HandshakeAction, Alert) { - if hm == nil { - logf(logTypeHandshake, "[StateConnected] Unexpected message") - return nil, nil, AlertUnexpectedMessage - } - - bodyGeneric, err := hm.ToBody() - if err != nil { - logf(logTypeHandshake, "[StateConnected] Error decoding message: %v", err) - return nil, nil, AlertDecodeError - } - - switch body := bodyGeneric.(type) { - case *KeyUpdateBody: - var trafficKeys keySet - if !state.isClient { - state.clientTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.clientTrafficSecret, - labelClientApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) - } else { - state.serverTrafficSecret = HkdfExpandLabel(state.cryptoParams.Hash, state.serverTrafficSecret, - labelServerApplicationTrafficSecret, []byte{}, state.cryptoParams.Hash.Size()) - trafficKeys = makeTrafficKeys(state.cryptoParams, state.serverTrafficSecret) - } - - toSend := []HandshakeAction{RekeyIn{epoch: EpochUpdate, KeySet: trafficKeys}} - - // If requested, roll outbound keys and send a KeyUpdate - if body.KeyUpdateRequest == KeyUpdateRequested { - logf(logTypeHandshake, "Received key update, update requested", body.KeyUpdateRequest) - moreToSend, alert := state.KeyUpdate(KeyUpdateNotRequested) - if alert != AlertNoAlert { - return nil, nil, alert - } - toSend = append(toSend, moreToSend...) - } - return state, toSend, AlertNoAlert - case *NewSessionTicketBody: - // XXX: Allow NewSessionTicket in both directions? - if !state.isClient { - return nil, nil, AlertUnexpectedMessage - } - - resumptionKey := HkdfExpandLabel(state.cryptoParams.Hash, state.resumptionSecret, - labelResumption, body.TicketNonce, state.cryptoParams.Hash.Size()) - psk := PreSharedKey{ - CipherSuite: state.cryptoParams.Suite, - IsResumption: true, - Identity: body.Ticket, - Key: resumptionKey, - NextProto: state.Params.NextProto, - ReceivedAt: time.Now(), - ExpiresAt: time.Now().Add(time.Duration(body.TicketLifetime) * time.Second), - TicketAgeAdd: body.TicketAgeAdd, - } - - toSend := []HandshakeAction{StorePSK{psk}} - return state, toSend, AlertNoAlert - } - - logf(logTypeHandshake, "[StateConnected] Unexpected message type %v", hm.msgType) - return nil, nil, AlertUnexpectedMessage -} diff --git a/vendor/github.com/bifurcation/mint/timer.go b/vendor/github.com/bifurcation/mint/timer.go deleted file mode 100644 index 0b7f7aff..00000000 --- a/vendor/github.com/bifurcation/mint/timer.go +++ /dev/null @@ -1,122 +0,0 @@ -package mint - -import ( - "time" -) - -// This is a simple timer implementation. Timers are stored in a sorted -// list. -// TODO(ekr@rtfm.com): Add a way to uncouple these from the system -// clock. -type timerCb func() error - -type timer struct { - label string - cb timerCb - deadline time.Time - duration uint32 -} - -type timerSet struct { - ts []*timer -} - -func newTimerSet() *timerSet { - return &timerSet{} -} - -func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer { - now := time.Now() - t := timer{ - label, - cb, - now.Add(time.Millisecond * time.Duration(delayMs)), - delayMs, - } - logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline) - - var i int - ntimers := len(ts.ts) - for i = 0; i < ntimers; i++ { - if t.deadline.Before(ts.ts[i].deadline) { - break - } - } - - tmp := make([]*timer, 0, ntimers+1) - tmp = append(tmp, ts.ts[:i]...) - tmp = append(tmp, &t) - tmp = append(tmp, ts.ts[i:]...) - ts.ts = tmp - - return &t -} - -// TODO(ekr@rtfm.com): optimize this now that the list is sorted. -// We should be able to do just one list manipulation, as long -// as we're careful about how we handle inserts during callbacks. -func (ts *timerSet) check(now time.Time) error { - for i, t := range ts.ts { - if now.After(t.deadline) { - ts.ts = append(ts.ts[:i], ts.ts[:i+1]...) - if t.cb != nil { - logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline) - cb := t.cb - t.cb = nil - err := cb() - if err != nil { - return err - } - } - } else { - break - } - } - return nil -} - -// Returns the next time any of the timers would fire. -func (ts *timerSet) remaining() (bool, time.Duration) { - for _, t := range ts.ts { - if t.cb != nil { - return true, time.Until(t.deadline) - } - } - - return false, time.Duration(0) -} - -func (ts *timerSet) cancel(label string) { - for _, t := range ts.ts { - if t.label == label { - t.cancel() - } - } -} - -func (ts *timerSet) getTimer(label string) *timer { - for _, t := range ts.ts { - if t.label == label && t.cb != nil { - return t - } - } - return nil -} - -func (ts *timerSet) getAllTimers() []string { - var ret []string - - for _, t := range ts.ts { - if t.cb != nil { - ret = append(ret, t.label) - } - } - - return ret -} - -func (t *timer) cancel() { - logf(logTypeHandshake, "Timer %s cancelled", t.label) - t.cb = nil - t.label = "" -} diff --git a/vendor/github.com/bifurcation/mint/tls.go b/vendor/github.com/bifurcation/mint/tls.go deleted file mode 100644 index 4d228692..00000000 --- a/vendor/github.com/bifurcation/mint/tls.go +++ /dev/null @@ -1,179 +0,0 @@ -package mint - -// XXX(rlb): This file is borrowed pretty much wholesale from crypto/tls - -import ( - "errors" - "net" - "strings" - "time" -) - -// Server returns a new TLS server side connection -// using conn as the underlying transport. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config) *Conn { - return NewConn(conn, config, false) -} - -// Client returns a new TLS client side connection -// using conn as the underlying transport. -// The config cannot be nil: users must set either ServerName or -// InsecureSkipVerify in the config. -func Client(conn net.Conn, config *Config) *Conn { - return NewConn(conn, config, true) -} - -// A listener implements a network listener (net.Listener) for TLS connections. -type Listener struct { - net.Listener - config *Config -} - -// Accept waits for and returns the next incoming TLS connection. -// The returned connection c is a *tls.Conn. -func (l *Listener) Accept() (c net.Conn, err error) { - c, err = l.Listener.Accept() - if err != nil { - return - } - server := Server(c, l.config) - err = server.Handshake() - if err == AlertNoAlert { - err = nil - } - c = server - return -} - -// NewListener creates a Listener which accepts connections from an inner -// Listener and wraps each connection with Server. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func NewListener(inner net.Listener, config *Config) (net.Listener, error) { - if config != nil && config.NonBlocking { - return nil, errors.New("listening not possible in non-blocking mode") - } - l := new(Listener) - l.Listener = inner - l.config = config - return l, nil -} - -// Listen creates a TLS listener accepting connections on the -// given network address using net.Listen. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Listen(network, laddr string, config *Config) (net.Listener, error) { - if config == nil || !config.ValidForServer() { - return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") - } - l, err := net.Listen(network, laddr) - if err != nil { - return nil, err - } - return NewListener(l, config) -} - -type TimeoutError struct{} - -func (TimeoutError) Error() string { return "tls: DialWithDialer timed out" } -func (TimeoutError) Timeout() bool { return true } -func (TimeoutError) Temporary() bool { return true } - -// DialWithDialer connects to the given network address using dialer.Dial and -// then initiates a TLS handshake, returning the resulting TLS connection. Any -// timeout or deadline given in the dialer apply to connection and TLS -// handshake as a whole. -// -// DialWithDialer interprets a nil configuration as equivalent to the zero -// configuration; see the documentation of Config for the defaults. -func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - if config != nil && config.NonBlocking { - return nil, errors.New("dialing not possible in non-blocking mode") - } - - // We want the Timeout and Deadline values from dialer to cover the - // whole process: TCP connection and TLS handshake. This means that we - // also need to start our own timers now. - timeout := dialer.Timeout - - if !dialer.Deadline.IsZero() { - deadlineTimeout := dialer.Deadline.Sub(time.Now()) - if timeout == 0 || deadlineTimeout < timeout { - timeout = deadlineTimeout - } - } - - var errChannel chan error - - if timeout != 0 { - errChannel = make(chan error, 2) - time.AfterFunc(timeout, func() { - errChannel <- TimeoutError{} - }) - } - - rawConn, err := dialer.Dial(network, addr) - if err != nil { - return nil, err - } - - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - - if config == nil { - config = &Config{} - } else { - config = config.Clone() - } - - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - if config.ServerName == "" { - config.ServerName = hostname - - } - - // Set up DTLS as needed. - config.UseDTLS = (network == "udp") - - conn := Client(rawConn, config) - - if timeout == 0 { - err = conn.Handshake() - if err == AlertNoAlert { - err = nil - } - } else { - go func() { - errChannel <- conn.Handshake() - }() - - err = <-errChannel - if err == AlertNoAlert { - err = nil - } - } - - if err != nil { - rawConn.Close() - return nil, err - } - - return conn, nil -} - -// Dial connects to the given network address using net.Dial -// and then initiates a TLS handshake, returning the resulting -// TLS connection. -// Dial interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. -func Dial(network, addr string, config *Config) (*Conn, error) { - return DialWithDialer(new(net.Dialer), network, addr, config) -} diff --git a/vendor/github.com/marten-seemann/qtls/13.go b/vendor/github.com/marten-seemann/qtls/13.go new file mode 100644 index 00000000..304f6691 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/13.go @@ -0,0 +1,1162 @@ +package qtls + +import ( + "bytes" + "crypto" + "crypto/cipher" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/hmac" + "crypto/rsa" + "crypto/subtle" + "encoding/hex" + "errors" + "fmt" + "hash" + "io" + "log" + "os" + "runtime" + "runtime/debug" + "strings" + "sync/atomic" + "time" + + "golang.org/x/crypto/curve25519" +) + +// numSessionTickets is the number of different session tickets the +// server sends to a TLS 1.3 client, who will use each only once. +const numSessionTickets = 2 + +type secretLabel int + +const ( + secretResumptionPskBinder secretLabel = iota + secretEarlyClient + secretHandshakeClient + secretHandshakeServer + secretApplicationClient + secretApplicationServer + secretResumption +) + +type keySchedule13 struct { + suite *cipherSuite + transcriptHash hash.Hash // uses the cipher suite hash algo + secret []byte // Current secret as used for Derive-Secret + handshakeCtx []byte // cached handshake context, invalidated on updates. + clientRandom []byte // Used for keylogging, nil if keylogging is disabled. + config *Config // Used for KeyLogWriter callback, nil if keylogging is disabled. +} + +func newKeySchedule13(suite *cipherSuite, config *Config, clientRandom []byte) *keySchedule13 { + if config.KeyLogWriter == nil { + clientRandom = nil + config = nil + } + return &keySchedule13{ + suite: suite, + transcriptHash: hashForSuite(suite).New(), + clientRandom: clientRandom, + config: config, + } +} + +// setSecret sets the early/handshake/master secret based on the given secret +// (IKM). The salt is based on previous secrets (nil for the early secret). +func (ks *keySchedule13) setSecret(secret []byte) { + hash := hashForSuite(ks.suite) + salt := ks.secret + if salt != nil { + h0 := hash.New().Sum(nil) + salt = hkdfExpandLabel(hash, salt, h0, "derived", hash.Size()) + } + ks.secret = hkdfExtract(hash, secret, salt) +} + +// write appends the data to the transcript hash context. +func (ks *keySchedule13) write(data []byte) { + ks.handshakeCtx = nil + ks.transcriptHash.Write(data) +} + +func (ks *keySchedule13) getLabel(secretLabel secretLabel) (label, keylogType string) { + switch secretLabel { + case secretResumptionPskBinder: + label = "res binder" + case secretEarlyClient: + label = "c e traffic" + keylogType = "CLIENT_EARLY_TRAFFIC_SECRET" + case secretHandshakeClient: + label = "c hs traffic" + keylogType = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" + case secretHandshakeServer: + label = "s hs traffic" + keylogType = "SERVER_HANDSHAKE_TRAFFIC_SECRET" + case secretApplicationClient: + label = "c ap traffic" + keylogType = "CLIENT_TRAFFIC_SECRET_0" + case secretApplicationServer: + label = "s ap traffic" + keylogType = "SERVER_TRAFFIC_SECRET_0" + case secretResumption: + label = "res master" + } + return +} + +// deriveSecret returns the secret derived from the handshake context and label. +func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte { + label, keylogType := ks.getLabel(secretLabel) + if ks.handshakeCtx == nil { + ks.handshakeCtx = ks.transcriptHash.Sum(nil) + } + hash := hashForSuite(ks.suite) + secret := hkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size()) + if keylogType != "" && ks.config != nil { + ks.config.writeKeyLog(keylogType, ks.clientRandom, secret) + } + return secret +} + +func (ks *keySchedule13) prepareCipher(trafficSecret []byte) cipher.AEAD { + hash := hashForSuite(ks.suite) + key := hkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen) + iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", ks.suite.ivLen) + return ks.suite.aead(key, iv) +} + +func (hs *serverHandshakeState) doTLS13Handshake() error { + config := hs.c.config + c := hs.c + + hs.c.cipherSuite, hs.hello.cipherSuite = hs.suite.id, hs.suite.id + hs.c.clientHello = hs.clientHello.marshal() + + // When picking the group for the handshake, priority is given to groups + // that the client provided a keyShare for, so to avoid a round-trip. + // After that the order of CurvePreferences is respected. + var ks keyShare +CurvePreferenceLoop: + for _, curveID := range config.curvePreferences() { + for _, keyShare := range hs.clientHello.keyShares { + if curveID == keyShare.group { + ks = keyShare + break CurvePreferenceLoop + } + } + } + if ks.group == 0 { + c.sendAlert(alertInternalError) + return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo) + } + + privateKey, serverKS, err := config.generateKeyShare(ks.group) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.keyShare = serverKS + + hash := hashForSuite(hs.suite) + hashSize := hash.Size() + hs.keySchedule = newKeySchedule13(hs.suite, config, hs.clientHello.random) + + // Check for PSK and update key schedule with new early secret key + isResumed, pskAlert := hs.checkPSK() + switch { + case pskAlert != alertSuccess: + c.sendAlert(pskAlert) + return errors.New("tls: invalid client PSK") + case !isResumed: + // apply an empty PSK if not resumed. + hs.keySchedule.setSecret(nil) + case isResumed: + c.didResume = true + } + + hs.keySchedule.write(hs.clientHello.marshal()) + + earlyClientTrafficSecret := hs.keySchedule.deriveSecret(secretEarlyClient) + + ecdheSecret := deriveECDHESecret(ks, privateKey) + if ecdheSecret == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: bad ECDHE client share") + } + + hs.keySchedule.write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + // middlebox compatibility mode: send CCS after first handshake message + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + hs.keySchedule.setSecret(ecdheSecret) + hs.hsClientTrafficSecret = hs.keySchedule.deriveSecret(secretHandshakeClient) + hsServerTrafficSecret := hs.keySchedule.deriveSecret(secretHandshakeServer) + c.out.setKey(c.vers, hs.keySchedule.suite, hsServerTrafficSecret) + + serverFinishedKey := hkdfExpandLabel(hash, hsServerTrafficSecret, nil, "finished", hashSize) + hs.clientFinishedKey = hkdfExpandLabel(hash, hs.hsClientTrafficSecret, nil, "finished", hashSize) + + // EncryptedExtensions + hs.keySchedule.write(hs.hello13Enc.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { + return err + } + + // TODO: we should have 2 separated methods - one for full-handshake and the other for PSK-handshake + if !c.didResume { + // Server MUST NOT send CertificateRequest if authenticating with PSK + if c.config.ClientAuth >= RequestClientCert { + + certReq := new(certificateRequestMsg13) + // extension 'signature_algorithms' MUST be specified + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms13 + certReq.supportedSignatureAlgorithmsCert = supportedSigAlgorithmsCert(supportedSignatureAlgorithms13) + hs.keySchedule.write(certReq.marshal()) + if _, err := hs.c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + if err := hs.sendCertificate13(); err != nil { + return err + } + } + + verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey) + serverFinished := &finishedMsg{ + verifyData: verifyData, + } + hs.keySchedule.write(serverFinished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil { + return err + } + + hs.keySchedule.setSecret(nil) // derive master secret + serverAppTrafficSecret := hs.keySchedule.deriveSecret(secretApplicationServer) + c.out.setKey(c.vers, hs.keySchedule.suite, serverAppTrafficSecret) + + if c.hand.Len() > 0 { + return c.sendAlert(alertUnexpectedMessage) + } + hs.appClientTrafficSecret = hs.keySchedule.deriveSecret(secretApplicationClient) + if hs.hello13Enc.earlyData { + c.in.setKey(c.vers, hs.keySchedule.suite, earlyClientTrafficSecret) + c.phase = readingEarlyData + } else { + c.in.setKey(c.vers, hs.keySchedule.suite, hs.hsClientTrafficSecret) + if hs.clientHello.earlyData { + c.phase = discardingEarlyData + } else { + c.phase = waitingClientFinished + } + } + + return nil +} + +// readClientFinished13 is called during the server handshake (when no early +// data it available) or after reading all early data. It discards early data if +// the server did not accept it and then verifies the Finished message. Once +// done it sends the session tickets. Under c.in lock. +func (hs *serverHandshakeState) readClientFinished13(hasConfirmLock bool) error { + c := hs.c + + // If the client advertised and sends early data while the server does + // not accept it, it must be fully skipped until the Finished message. + for c.phase == discardingEarlyData { + if err := c.readRecord(recordTypeApplicationData); err != nil { + return err + } + // Assume receipt of Finished message (will be checked below). + if c.hand.Len() > 0 { + c.phase = waitingClientFinished + break + } + } + + // If the client sends early data followed by a Finished message (but + // no end_of_early_data), the server MUST terminate the connection. + if c.phase != waitingClientFinished { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: did not expect Client Finished yet") + } + + c.phase = readingClientFinished + msg, err := c.readHandshake() + if err != nil { + return err + } + + // client authentication + if certMsg, ok := msg.(*certificateMsg13); ok { + + // (4.4.2) Client MUST send certificate msg if requested by server + if c.config.ClientAuth < RequestClientCert { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + + hs.keySchedule.write(certMsg.marshal()) + certs := getCertsFromEntries(certMsg.certificates) + pubKey, err := hs.processCertsFromClient(certs) + if err != nil { + return err + } + + // 4.4.3: CertificateVerify MUST appear immediately after Certificate msg + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + err, alertCode := verifyPeerHandshakeSignature( + certVerify, + pubKey, + supportedSignatureAlgorithms13, + hs.keySchedule.transcriptHash.Sum(nil), + "TLS 1.3, client CertificateVerify") + if err != nil { + c.sendAlert(alertCode) + return err + } + hs.keySchedule.write(certVerify.marshal()) + + // Read next chunk + msg, err = c.readHandshake() + if err != nil { + return err + } + + } else if (c.config.ClientAuth >= RequestClientCert) && !c.didResume { + c.sendAlert(alertCertificateRequired) + return unexpectedMessageError(certMsg, msg) + } + + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + hash := hashForSuite(hs.suite) + expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, hs.clientFinishedKey) + if len(expectedVerifyData) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { + c.sendAlert(alertDecryptError) + return errors.New("tls: client's Finished message is incorrect") + } + hs.keySchedule.write(clientFinished.marshal()) + + c.hs = nil // Discard the server handshake state + if c.hand.Len() > 0 { + return c.sendAlert(alertUnexpectedMessage) + } + c.in.setKey(c.vers, hs.keySchedule.suite, hs.appClientTrafficSecret) + c.in.traceErr, c.out.traceErr = nil, nil + c.phase = handshakeConfirmed + atomic.StoreInt32(&c.handshakeConfirmed, 1) + + // Any read operation after handshakeRunning and before handshakeConfirmed + // will be holding this lock, which we release as soon as the confirmation + // happens, even if the Read call might do more work. + // If a Handshake is pending, c.confirmMutex will never be locked as + // ConfirmHandshake will wait for the handshake to complete. If a + // handshake was complete, and this was a confirmation, unlock + // c.confirmMutex now to allow readers to proceed. + if hasConfirmLock { + c.confirmMutex.Unlock() + } + + return hs.sendSessionTicket13() // TODO: do in a goroutine +} + +func (hs *serverHandshakeState) sendCertificate13() error { + c := hs.c + + certEntries := []certificateEntry{} + for _, cert := range hs.cert.Certificate { + certEntries = append(certEntries, certificateEntry{data: cert}) + } + if len(certEntries) > 0 && hs.clientHello.ocspStapling { + certEntries[0].ocspStaple = hs.cert.OCSPStaple + } + if len(certEntries) > 0 && hs.clientHello.scts { + certEntries[0].sctList = hs.cert.SignedCertificateTimestamps + } + + // If hs.delegatedCredential is set (see hs.readClientHello()) then the + // server is using the delegated credential extension. The DC is added as an + // extension to the end-entity certificate, i.e., the last CertificateEntry + // of Certificate.certficate_list. (For details, see + // https://tools.ietf.org/html/draft-ietf-tls-subcerts-02.) + if len(certEntries) > 0 && hs.clientHello.delegatedCredential && hs.delegatedCredential != nil { + certEntries[0].delegatedCredential = hs.delegatedCredential + } + + certMsg := &certificateMsg13{certificates: certEntries} + + hs.keySchedule.write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + sigScheme, err := hs.selectTLS13SignatureScheme() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + sigHash := hashForSignatureScheme(sigScheme) + opts := crypto.SignerOpts(sigHash) + if signatureSchemeIsPSS(sigScheme) { + opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + + toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil)) + signature, err := hs.privateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + verifyMsg := &certificateVerifyMsg{ + hasSignatureAndHash: true, + signatureAlgorithm: sigScheme, + signature: signature, + } + hs.keySchedule.write(verifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (c *Conn) handleEndOfEarlyData() error { + if c.phase != readingEarlyData || c.vers < VersionTLS13 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + msg, err := c.readHandshake() + if err != nil { + return err + } + endOfEarlyData, ok := msg.(*endOfEarlyDataMsg) + // No handshake messages are allowed after EOD. + if !ok || c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.hs.keySchedule.write(endOfEarlyData.marshal()) + c.phase = waitingClientFinished + c.in.setKey(c.vers, c.hs.keySchedule.suite, c.hs.hsClientTrafficSecret) + return nil +} + +// selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify +// based on the certificate type and client supported schemes. If no overlap is found, +// a fallback is selected. +// +// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.4.1.2 +func (hs *serverHandshakeState) selectTLS13SignatureScheme() (sigScheme SignatureScheme, err error) { + var supportedSchemes []SignatureScheme + signer, ok := hs.privateKey.(crypto.Signer) + if !ok { + return 0, errors.New("tls: private key does not implement crypto.Signer") + } + pk := signer.Public() + if _, ok := pk.(*rsa.PublicKey); ok { + sigScheme = PSSWithSHA256 + supportedSchemes = []SignatureScheme{PSSWithSHA256, PSSWithSHA384, PSSWithSHA512} + } else if pk, ok := pk.(*ecdsa.PublicKey); ok { + switch pk.Curve { + case elliptic.P256(): + sigScheme = ECDSAWithP256AndSHA256 + supportedSchemes = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigScheme = ECDSAWithP384AndSHA384 + supportedSchemes = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigScheme = ECDSAWithP521AndSHA512 + supportedSchemes = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return 0, errors.New("tls: unknown ECDSA certificate curve") + } + } else { + return 0, errors.New("tls: unknown certificate key type") + } + + for _, ss := range supportedSchemes { + for _, cs := range hs.clientHello.supportedSignatureAlgorithms { + if ss == cs { + return ss, nil + } + } + } + + return sigScheme, nil +} + +func signatureSchemeIsPSS(s SignatureScheme) bool { + return s == PSSWithSHA256 || s == PSSWithSHA384 || s == PSSWithSHA512 +} + +// hashForSignatureScheme returns the Hash used by a SignatureScheme which is +// supported by selectTLS13SignatureScheme. +func hashForSignatureScheme(ss SignatureScheme) crypto.Hash { + switch ss { + case PSSWithSHA256, ECDSAWithP256AndSHA256: + return crypto.SHA256 + case PSSWithSHA384, ECDSAWithP384AndSHA384: + return crypto.SHA384 + case PSSWithSHA512, ECDSAWithP521AndSHA512: + return crypto.SHA512 + default: + panic("unsupported SignatureScheme passed to hashForSignatureScheme") + } +} + +func hashForSuite(suite *cipherSuite) crypto.Hash { + if suite.flags&suiteSHA384 != 0 { + return crypto.SHA384 + } + return crypto.SHA256 +} + +func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte { + message := bytes.Repeat([]byte{32}, 64) + message = append(message, context...) + message = append(message, 0) + message = append(message, data...) + h := hash.New() + h.Write(message) + return h.Sum(nil) +} + +func (c *Config) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) { + if curveID == X25519 { + var scalar, public [32]byte + if _, err := io.ReadFull(c.rand(), scalar[:]); err != nil { + return nil, keyShare{}, err + } + + curve25519.ScalarBaseMult(&public, &scalar) + return scalar[:], keyShare{group: curveID, data: public[:]}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve") + } + + privateKey, x, y, err := elliptic.GenerateKey(curve, c.rand()) + if err != nil { + return nil, keyShare{}, err + } + ecdhePublic := elliptic.Marshal(curve, x, y) + + return privateKey, keyShare{group: curveID, data: ecdhePublic}, nil +} + +func deriveECDHESecret(ks keyShare, secretKey []byte) []byte { + if ks.group == X25519 { + if len(ks.data) != 32 { + return nil + } + + var theirPublic, sharedKey, scalar [32]byte + copy(theirPublic[:], ks.data) + copy(scalar[:], secretKey) + curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) + return sharedKey[:] + } + + curve, ok := curveForCurveID(ks.group) + if !ok { + return nil + } + x, y := elliptic.Unmarshal(curve, ks.data) + if x == nil { + return nil + } + x, _ = curve.ScalarMult(x, y, secretKey) + xBytes := x.Bytes() + curveSize := (curve.Params().BitSize + 8 - 1) >> 3 + if len(xBytes) == curveSize { + return xBytes + } + buf := make([]byte, curveSize) + copy(buf[len(buf)-len(xBytes):], xBytes) + return buf +} + +func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte { + prefix := "tls13 " + hkdfLabel := make([]byte, 4+len(prefix)+len(label)+len(hashValue)) + hkdfLabel[0] = byte(L >> 8) + hkdfLabel[1] = byte(L) + hkdfLabel[2] = byte(len(prefix) + len(label)) + copy(hkdfLabel[3:], prefix) + z := hkdfLabel[3+len(prefix):] + copy(z, label) + z = z[len(label):] + z[0] = byte(len(hashValue)) + copy(z[1:], hashValue) + + return hkdfExpand(hash, secret, hkdfLabel, L) +} + +func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte { + h := hmac.New(f.New, key) + h.Write(hash.Sum(nil)) + return h.Sum(nil) +} + +// Maximum allowed mismatch between the stated age of a ticket +// and the server-observed one. See +// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. +const ticketAgeSkewAllowance = 10 * time.Second + +// checkPSK tries to resume using a PSK, returning true (and updating the +// early secret in the key schedule) if the PSK was used and false otherwise. +func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) { + if hs.c.config.SessionTicketsDisabled { + return false, alertSuccess + } + + foundDHE := false + for _, mode := range hs.clientHello.pskKeyExchangeModes { + if mode == pskDHEKeyExchange { + foundDHE = true + break + } + } + if !foundDHE { + return false, alertSuccess + } + + hash := hashForSuite(hs.suite) + hashSize := hash.Size() + for i := range hs.clientHello.psks { + sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...) + if hs.c.config.SessionTicketSealer != nil { + var ok bool + sessionTicket, ok = hs.c.config.SessionTicketSealer.Unseal(hs.clientHelloInfo(), sessionTicket) + if !ok { + continue + } + } else { + sessionTicket, _ = hs.c.decryptTicket(sessionTicket) + if sessionTicket == nil { + continue + } + } + s := &sessionState13{} + if s.unmarshal(sessionTicket) != alertSuccess { + continue + } + if s.vers != hs.c.vers { + continue + } + clientAge := time.Duration(hs.clientHello.psks[i].obfTicketAge-s.ageAdd) * time.Millisecond + serverAge := time.Since(time.Unix(int64(s.createdAt), 0)) + if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance { + // XXX: NSS is off spec and sends obfuscated_ticket_age as seconds + clientAge = time.Duration(hs.clientHello.psks[i].obfTicketAge-s.ageAdd) * time.Second + if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance { + continue + } + } + + // This enforces the stricter 0-RTT requirements on all ticket uses. + // The benefit of using PSK+ECDHE without 0-RTT are small enough that + // we can give them up in the edge case of changed suite or ALPN or SNI. + if s.suite != hs.suite.id { + continue + } + if s.alpnProtocol != hs.c.clientProtocol { + continue + } + if s.SNI != hs.c.serverName { + continue + } + + hs.keySchedule.setSecret(s.pskSecret) + binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder) + binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) + chHash := hash.New() + chHash.Write(hs.clientHello.rawTruncated) + expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) + + if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 { + return false, alertDecryptError + } + + if i == 0 && hs.clientHello.earlyData { + // This is a ticket intended to be used for 0-RTT + if s.maxEarlyDataLen == 0 { + // But we had not tagged it as such. + return false, alertIllegalParameter + } + if hs.c.config.Accept0RTTData { + hs.c.binder = expectedBinder + hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen) + hs.hello13Enc.earlyData = true + } + } + hs.hello.psk = true + hs.hello.pskIdentity = uint16(i) + return true, alertSuccess + } + + return false, alertSuccess +} + +func (hs *serverHandshakeState) sendSessionTicket13() error { + c := hs.c + if c.config.SessionTicketsDisabled { + return nil + } + + foundDHE := false + for _, mode := range hs.clientHello.pskKeyExchangeModes { + if mode == pskDHEKeyExchange { + foundDHE = true + break + } + } + if !foundDHE { + return nil + } + + resumptionMasterSecret := hs.keySchedule.deriveSecret(secretResumption) + + ageAddBuf := make([]byte, 4) + sessionState := &sessionState13{ + vers: c.vers, + suite: hs.suite.id, + createdAt: uint64(time.Now().Unix()), + alpnProtocol: c.clientProtocol, + SNI: c.serverName, + maxEarlyDataLen: c.config.Max0RTTDataSize, + } + hash := hashForSuite(hs.suite) + + for i := 0; i < numSessionTickets; i++ { + if _, err := io.ReadFull(c.config.rand(), ageAddBuf); err != nil { + c.sendAlert(alertInternalError) + return err + } + sessionState.ageAdd = uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 | + uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]) + // ticketNonce must be a unique value for this connection. + // Assume there are no more than 255 tickets, otherwise two + // tickets might have the same PSK which could be a problem if + // one of them is compromised. + ticketNonce := []byte{byte(i)} + sessionState.pskSecret = hkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size()) + ticket := sessionState.marshal() + var err error + if c.config.SessionTicketSealer != nil { + cs := c.ConnectionState() + ticket, err = c.config.SessionTicketSealer.Seal(&cs, ticket) + } else { + ticket, err = c.encryptTicket(ticket) + } + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ticket == nil { + continue + } + ticketMsg := &newSessionTicketMsg13{ + lifetime: 24 * 3600, // TODO(filippo) + maxEarlyDataLength: c.config.Max0RTTDataSize, + withEarlyDataInfo: c.config.Max0RTTDataSize > 0, + ageAdd: sessionState.ageAdd, + nonce: ticketNonce, + ticket: ticket, + } + if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil { + return err + } + } + + return nil +} + +func (hs *serverHandshakeState) traceErr(err error) { + if err == nil { + return + } + if os.Getenv("TLSDEBUG") == "error" { + if hs != nil && hs.clientHello != nil { + os.Stderr.WriteString(hex.Dump(hs.clientHello.marshal())) + } else if err == io.EOF { + return // don't stack trace on EOF before CH + } + fmt.Fprintf(os.Stderr, "\n%s\n", debug.Stack()) + } + if os.Getenv("TLSDEBUG") == "short" { + var pcs [4]uintptr + frames := runtime.CallersFrames(pcs[0:runtime.Callers(3, pcs[:])]) + for { + frame, more := frames.Next() + if frame.Function != "crypto/tls.(*halfConn).setErrorLocked" && + frame.Function != "crypto/tls.(*Conn).sendAlertLocked" && + frame.Function != "crypto/tls.(*Conn).sendAlert" { + file := frame.File[strings.LastIndex(frame.File, "/")+1:] + log.Printf("%s:%d (%s): %v", file, frame.Line, frame.Function, err) + return + } + if !more { + break + } + } + } +} + +func getCertsFromEntries(certEntries []certificateEntry) [][]byte { + certs := make([][]byte, len(certEntries)) + for i, cert := range certEntries { + certs[i] = cert.data + } + return certs +} + +func (hs *clientHandshakeState) processEncryptedExtensions(ee *encryptedExtensionsMsg) error { + c := hs.c + if ee.alpnProtocol != "" { + c.clientProtocol = ee.alpnProtocol + c.clientProtocolFallback = false + } + if hs.c.config.ReceivedExtensions != nil { + return hs.c.config.ReceivedExtensions(typeEncryptedExtensions, ee.additionalExtensions) + } + return nil +} + +func verifyPeerHandshakeSignature( + certVerify *certificateVerifyMsg, + pubKey crypto.PublicKey, + signAlgosKnown []SignatureScheme, + transHash []byte, + contextString string) (error, alert) { + + _, sigType, hashFunc, err := pickSignatureAlgorithm( + pubKey, + []SignatureScheme{certVerify.signatureAlgorithm}, + signAlgosKnown, + VersionTLS13) + if err != nil { + return err, alertHandshakeFailure + } + + digest := prepareDigitallySigned(hashFunc, contextString, transHash) + err = verifyHandshakeSignature(sigType, pubKey, hashFunc, digest, certVerify.signature) + + if err != nil { + return err, alertDecryptError + } + + return nil, alertSuccess +} + +func (hs *clientHandshakeState) getCertificate13(certReq *certificateRequestMsg13) (*Certificate, error) { + certReq12 := &certificateRequestMsg{ + hasSignatureAndHash: true, + supportedSignatureAlgorithms: certReq.supportedSignatureAlgorithms, + certificateAuthorities: certReq.certificateAuthorities, + } + + var rsaAvail, ecdsaAvail bool + for _, sigAlg := range certReq.supportedSignatureAlgorithms { + switch signatureFromSignatureScheme(sigAlg) { + case signaturePKCS1v15, signatureRSAPSS: + rsaAvail = true + case signatureECDSA: + ecdsaAvail = true + } + } + if rsaAvail { + certReq12.certificateTypes = append(certReq12.certificateTypes, certTypeRSASign) + } + if ecdsaAvail { + certReq12.certificateTypes = append(certReq12.certificateTypes, certTypeECDSASign) + } + + return hs.getCertificate(certReq12) +} + +func (hs *clientHandshakeState) sendCertificate13(chainToSend *Certificate, certReq *certificateRequestMsg13) error { + c := hs.c + + certEntries := []certificateEntry{} + for _, cert := range chainToSend.Certificate { + certEntries = append(certEntries, certificateEntry{data: cert}) + } + certMsg := &certificateMsg13{certificates: certEntries} + + hs.keySchedule.write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if len(certEntries) == 0 { + // No client cert available, nothing to sign. + return nil + } + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(key.Public(), certReq.supportedSignatureAlgorithms, hs.hello.supportedSignatureAlgorithms, c.vers) + if err != nil { + hs.c.sendAlert(alertHandshakeFailure) + return err + } + + digest := prepareDigitallySigned(hashFunc, "TLS 1.3, client CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil)) + signOpts := crypto.SignerOpts(hashFunc) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc} + } + signature, err := key.Sign(c.config.rand(), digest, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + verifyMsg := &certificateVerifyMsg{ + hasSignatureAndHash: true, + signatureAlgorithm: signatureAlgorithm, + signature: signature, + } + hs.keySchedule.write(verifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeState) doTLS13Handshake() error { + c := hs.c + hash := hashForSuite(hs.suite) + hashSize := hash.Size() + serverHello := hs.serverHello + c.scts = serverHello.scts + + // middlebox compatibility mode, send CCS before second flight. + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + // TODO check if keyshare is unacceptable, raise HRR. + + clientKS := hs.hello.keyShares[0] + if serverHello.keyShare.group != clientKS.group { + c.sendAlert(alertIllegalParameter) + return errors.New("bad or missing key share from server") + } + + // 0-RTT is not supported yet, so use an empty PSK. + hs.keySchedule.setSecret(nil) + ecdheSecret := deriveECDHESecret(serverHello.keyShare, hs.privateKey) + if ecdheSecret == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: bad ECDHE server share") + } + + // Calculate handshake secrets. + hs.keySchedule.setSecret(ecdheSecret) + clientHandshakeSecret := hs.keySchedule.deriveSecret(secretHandshakeClient) + if c.hand.Len() > 0 { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: unexpected data after Server Hello") + } + // Do not change the sender key yet, the server must authenticate first. + serverHandshakeSecret := hs.keySchedule.deriveSecret(secretHandshakeServer) + c.in.setKey(c.vers, hs.keySchedule.suite, serverHandshakeSecret) + + // Calculate MAC key for Finished messages. + serverFinishedKey := hkdfExpandLabel(hash, serverHandshakeSecret, nil, "finished", hashSize) + clientFinishedKey := hkdfExpandLabel(hash, clientHandshakeSecret, nil, "finished", hashSize) + + msg, err := c.readHandshake() + if err != nil { + return err + } + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + if err := hs.processEncryptedExtensions(encryptedExtensions); err != nil { + return err + } + hs.keySchedule.write(encryptedExtensions.marshal()) + + // PSKs are not supported, so receive Certificate message. + msg, err = c.readHandshake() + if err != nil { + return err + } + + var chainToSend *Certificate + certReq, isCertRequested := msg.(*certificateRequestMsg13) + if isCertRequested { + hs.keySchedule.write(certReq.marshal()) + + if chainToSend, err = hs.getCertificate13(certReq); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + certMsg, ok := msg.(*certificateMsg13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.keySchedule.write(certMsg.marshal()) + + // Validate certificates. + certs := getCertsFromEntries(certMsg.certificates) + if err := hs.processCertsFromServer(certs); err != nil { + return err + } + + // Receive CertificateVerify message. + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerifyMsg, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerifyMsg, msg) + } + + // Validate the DC if present. The DC is only processed if the extension was + // indicated by the ClientHello; otherwise this call will result in an + // "illegal_parameter" alert. + if len(certMsg.certificates) > 0 { + if err := hs.processDelegatedCredentialFromServer( + certMsg.certificates[0].delegatedCredential, + certVerifyMsg.signatureAlgorithm); err != nil { + return err + } + } + + // Set the public key used to verify the handshake. + pk := hs.c.peerCertificates[0].PublicKey + + // If the delegated credential extension has successfully been negotiated, + // then the CertificateVerify signature will have been produced with the + // DelegatedCredential's private key. + if hs.c.verifiedDc != nil { + pk = hs.c.verifiedDc.cred.publicKey + } + + // Verify the handshake signature. + err, alertCode := verifyPeerHandshakeSignature( + certVerifyMsg, + pk, + hs.hello.supportedSignatureAlgorithms, + hs.keySchedule.transcriptHash.Sum(nil), + "TLS 1.3, server CertificateVerify") + if err != nil { + c.sendAlert(alertCode) + return err + } + hs.keySchedule.write(certVerifyMsg.marshal()) + + // Receive Finished message. + msg, err = c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + // Validate server Finished hash. + expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey) + if subtle.ConstantTimeCompare(expectedVerifyData, serverFinished.verifyData) != 1 { + c.sendAlert(alertDecryptError) + return errors.New("tls: server's Finished message is incorrect") + } + hs.keySchedule.write(serverFinished.marshal()) + + // Server has authenticated itself. Calculate application traffic secrets. + hs.keySchedule.setSecret(nil) // derive master secret + + // Change outbound handshake cipher for final step + c.out.setKey(c.vers, hs.keySchedule.suite, clientHandshakeSecret) + + clientAppTrafficSecret := hs.keySchedule.deriveSecret(secretApplicationClient) + serverAppTrafficSecret := hs.keySchedule.deriveSecret(secretApplicationServer) + // TODO store initial traffic secret key for KeyUpdate GH #85 + + // Client auth requires sending a (possibly empty) Certificate followed + // by a CertificateVerify message (if there was an actual certificate). + if isCertRequested { + if err := hs.sendCertificate13(chainToSend, certReq); err != nil { + return err + } + } + + // Send Finished + verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, clientFinishedKey) + clientFinished := &finishedMsg{ + verifyData: verifyData, + } + if _, err := c.writeRecord(recordTypeHandshake, clientFinished.marshal()); err != nil { + return err + } + + // Handshake done, set application traffic secret + // TODO store initial traffic secret key for KeyUpdate GH #85 + c.out.setKey(c.vers, hs.keySchedule.suite, clientAppTrafficSecret) + if c.hand.Len() > 0 { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: unexpected data after handshake") + } + c.in.setKey(c.vers, hs.keySchedule.suite, serverAppTrafficSecret) + return nil +} + +// supportedSigAlgorithmsCert iterates over schemes and filters out those algorithms +// which are not supported for certificate verification. +func supportedSigAlgorithmsCert(schemes []SignatureScheme) (ret []SignatureScheme) { + for _, sig := range schemes { + // X509 doesn't support PSS signatures + if !signatureSchemeIsPSS(sig) { + ret = append(ret, sig) + } + } + return +} diff --git a/vendor/github.com/marten-seemann/qtls/README.md b/vendor/github.com/marten-seemann/qtls/README.md new file mode 100644 index 00000000..be5c08c1 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/README.md @@ -0,0 +1,107 @@ +``` + _____ _ ____ _ _ +|_ _| | / ___| | |_ _ __(_)___ + | | | | \___ \ _____| __| '__| / __| + | | | |___ ___) |_____| |_| | | \__ \ + |_| |_____|____/ \__|_| |_|___/ + +``` + +crypto/tls, now with 100% more 1.3. + +THE API IS NOT STABLE AND DOCUMENTATION IS NOT GUARANTEED. + +[![Build Status](https://travis-ci.org/cloudflare/tls-tris.svg?branch=master)](https://travis-ci.org/cloudflare/tls-tris) + +## Usage + +Since `crypto/tls` is very deeply (and not that elegantly) coupled with the Go stdlib, +tls-tris shouldn't be used as an external package. It is also impossible to vendor it +as `crypto/tls` because stdlib packages would import the standard one and mismatch. + +So, to build with tls-tris, you need to use a custom GOROOT. + +A script is provided that will take care of it for you: `./_dev/go.sh`. +Just use that instead of the `go` tool. + +The script also transparently fetches the custom Cloudflare Go 1.10 compiler with the required backports. + +## Development + +### Dependencies + +Copy paste line bellow to install all required dependencies: + +* ArchLinux: +``` +pacman -S go docker gcc git make patch python2 python-docker rsync +``` + +* Debian: +``` +apt-get install build-essential docker go patch python python-pip rsync +pip install setuptools +pip install docker +``` + +* Ubuntu (18.04) : +``` +apt-get update +apt-get install build-essential docker docker.io golang patch python python-pip rsync sudo +pip install setuptools +pip install docker +sudo usermod -a -G docker $USER +``` + +Similar dependencies can be found on any UNIX based system/distribution. + +### Building + +There are number of things that need to be setup before running tests. Most important step is to copy ``go env GOROOT`` directory to ``_dev`` and swap TLS implementation and recompile GO. Then for testing we use go implementation from ``_dev/GOROOT``. + +``` +git clone https://github.com/cloudflare/tls-tris.git +cd tls-tris; cp _dev/utils/pre-commit .git/hooks/ +make -f _dev/Makefile build-all +``` + +### Testing + +We run 3 kinds of test:. + +* Unit testing:
``make -f _dev/Makefile test-unit`` +* Testing against BoringSSL test suite:
``make -f _dev/Makefile test-bogo`` +* Compatibility testing (see below):
``make -f _dev/Makefile test-interop`` + +To run all the tests in one go use: +``` +make -f _dev/Makefile test +``` + +### Testing interoperability with 3rd party libraries + +In order to ensure compatibility we are testing our implementation against BoringSSL, NSS and PicoTLS. + +Makefile has a specific target for testing interoperability with external libraries. Following command can be used in order to run such test: + +``` +make -f _dev/Makefile test-interop +``` + +The makefile target is just a wrapper and it executes ``_dev/interop_test_runner`` script written in python. The script implements interoperability tests using ``python unittest`` framework. + +Script can be started from command line directly. For example: + +``` +> ./interop_test_runner -v InteropServer_NSS.test_zero_rtt +test_zero_rtt (__main__.InteropServer_NSS) ... ok + +---------------------------------------------------------------------- +Ran 1 test in 8.765s + +OK +``` + +### Debugging + +When the environment variable `TLSDEBUG` is set to `error`, Tris will print a hexdump of the Client Hello and a stack trace if an handshake error occurs. If the value is `short`, only the error and the first meaningful stack frame are printed. diff --git a/vendor/github.com/marten-seemann/qtls/alert.go b/vendor/github.com/marten-seemann/qtls/alert.go new file mode 100644 index 00000000..bfd552d5 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/alert.go @@ -0,0 +1,84 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import "strconv" + +type alert uint8 + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertInappropriateFallback alert = 86 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 + alertUnsupportedExtension alert = 110 + alertCertificateRequired alert = 116 + alertNoApplicationProtocol alert = 120 + alertSuccess alert = 255 // dummy value returned by unmarshal functions +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertInappropriateFallback: "inappropriate fallback", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", + alertNoApplicationProtocol: "no application protocol", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return "tls: " + s + } + return "tls: alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/marten-seemann/qtls/auth.go b/vendor/github.com/marten-seemann/qtls/auth.go new file mode 100644 index 00000000..3e3b3fd2 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/auth.go @@ -0,0 +1,107 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "encoding/asn1" + "errors" + "fmt" +) + +// pickSignatureAlgorithm selects a signature algorithm that is compatible with +// the given public key and the list of algorithms from the peer and this side. +// +// The returned SignatureScheme codepoint is only meaningful for TLS 1.2, +// previous TLS versions have a fixed hash function. +func pickSignatureAlgorithm(pubkey crypto.PublicKey, peerSigAlgs, ourSigAlgs []SignatureScheme, tlsVersion uint16) (SignatureScheme, uint8, crypto.Hash, error) { + if tlsVersion < VersionTLS12 || len(peerSigAlgs) == 0 { + // If the client didn't specify any signature_algorithms + // extension then we can assume that it supports SHA1. See + // http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + switch pubkey.(type) { + case *rsa.PublicKey: + if tlsVersion < VersionTLS12 { + return 0, signaturePKCS1v15, crypto.MD5SHA1, nil + } else { + return PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1, nil + } + case *ecdsa.PublicKey: + return ECDSAWithSHA1, signatureECDSA, crypto.SHA1, nil + default: + return 0, 0, 0, fmt.Errorf("tls: unsupported public key: %T", pubkey) + } + } + for _, sigAlg := range peerSigAlgs { + if !isSupportedSignatureAlgorithm(sigAlg, ourSigAlgs) { + continue + } + hashAlg, err := lookupTLSHash(sigAlg) + if err != nil { + panic("tls: supported signature algorithm has an unknown hash function") + } + sigType := signatureFromSignatureScheme(sigAlg) + if (sigType == signaturePKCS1v15 || hashAlg == crypto.SHA1) && tlsVersion >= VersionTLS13 { + // TLS 1.3 forbids RSASSA-PKCS1-v1_5 and SHA-1 for + // handshake messages. + continue + } + switch pubkey.(type) { + case *rsa.PublicKey: + if sigType == signaturePKCS1v15 || sigType == signatureRSAPSS { + return sigAlg, sigType, hashAlg, nil + } + case *ecdsa.PublicKey: + if sigType == signatureECDSA { + return sigAlg, sigType, hashAlg, nil + } + } + } + return 0, 0, 0, errors.New("tls: peer doesn't support any common signature algorithms") +} + +// verifyHandshakeSignature verifies a signature against pre-hashed handshake +// contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, digest, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return errors.New("tls: ECDSA signing requires a ECDSA public key") + } + ecdsaSig := new(ecdsaSignature) + if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { + return err + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errors.New("tls: ECDSA signature contained zero or negative values") + } + if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) { + return errors.New("tls: ECDSA verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return errors.New("tls: RSA signing requires a RSA public key") + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, digest, sig); err != nil { + return err + } + case signatureRSAPSS: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return errors.New("tls: RSA signing requires a RSA public key") + } + signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + if err := rsa.VerifyPSS(pubKey, hashFunc, digest, sig, signOpts); err != nil { + return err + } + default: + return errors.New("tls: unknown signature algorithm") + } + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls/cipher_suites.go b/vendor/github.com/marten-seemann/qtls/cipher_suites.go new file mode 100644 index 00000000..3bbc0b90 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/cipher_suites.go @@ -0,0 +1,437 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "hash" + + "golang.org/x/crypto/chacha20poly1305" +) + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*Config, crypto.PrivateKey, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*Config, crypto.PrivateKey, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, crypto.PublicKey, *serverKeyExchangeMsg) error + generateClientKeyExchange(*Config, *clientHelloMsg, crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error) +} + +const ( + // suiteECDH indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECDSA indicates that the cipher suite involves an ECDSA + // signature and therefore may only be selected when the server's + // certificate is ECDSA. If this is not set then the cipher suite is + // RSA based. + suiteECDSA + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteTLS13 indicates that the ones and only cipher suites to be + // advertised and accepted when using TLS 1.3. + suiteTLS13 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 + // suiteDefaultOff indicates that this cipher suite is not included by + // default. + suiteDefaultOff +) + +// A cipherSuite is a specific combination of key agreement, cipher and MAC +// function. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) interface{} + mac func(version uint16, macKey []byte) macFunction + aead func(key, fixedNonce []byte) cipher.AEAD +} + +type CipherSuite struct { + cipherSuite +} + +func (c *CipherSuite) Hash() crypto.Hash { return hashForSuite(&c.cipherSuite) } +func (c *CipherSuite) KeyLen() int { return c.keyLen } +func (c *CipherSuite) IVLen() int { return c.ivLen } +func (c *CipherSuite) AEAD(key, fixedNonce []byte) cipher.AEAD { return c.aead(key, fixedNonce) } + +var cipherSuites = []*cipherSuite{ + // TLS 1.3 ciphersuites specify only the AEAD and the HKDF hash. + {TLS_CHACHA20_POLY1305_SHA256, 32, 0, 12, nil, suiteTLS13, nil, nil, aeadChaCha20Poly1305}, + {TLS_AES_128_GCM_SHA256, 16, 0, 12, nil, suiteTLS13, nil, nil, aeadAESGCM13}, + {TLS_AES_256_GCM_SHA384, 32, 0, 12, nil, suiteTLS13 | suiteSHA384, nil, nil, aeadAESGCM13}, + + // Ciphersuite order is chosen so that ECDHE comes before plain RSA and + // AEADs are the top preference. + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM12}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12, nil, nil, aeadAESGCM12}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM12}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM12}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM12}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM12}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + + // RC4-based cipher suites are disabled by default. + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteDefaultOff, cipherRC4, macSHA1, nil}, +} + +func cipherRC4(key, iv []byte, isRead bool) interface{} { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) interface{} { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) interface{} { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a macFunction for the given protocol version. +func macSHA1(version uint16, key []byte) macFunction { + if version == VersionSSL30 { + mac := ssl30MAC{ + h: sha1.New(), + key: make([]byte, len(key)), + } + copy(mac.key, key) + return mac + } + return tls10MAC{hmac.New(newConstantTimeHash(sha1.New), key)} +} + +// macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 +// so the given version is ignored. +func macSHA256(version uint16, key []byte) macFunction { + return tls10MAC{hmac.New(sha256.New, key)} +} + +type macFunction interface { + Size() int + MAC(digestBuf, seq, header, data, extra []byte) []byte +} + +type aead interface { + cipher.AEAD + + // explicitIVLen returns the number of bytes used by the explicit nonce + // that is included in the record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +// fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type fixedNonceAEAD struct { + // nonce contains the fixed part of the nonce in the first four bytes. + nonce [12]byte + aead cipher.AEAD +} + +func (f *fixedNonceAEAD) NonceSize() int { return 8 } + +// Overhead returns the maximum difference between the lengths of a +// plaintext and its ciphertext. +func (f *fixedNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *fixedNonceAEAD) explicitNonceLen() int { return 8 } + +func (f *fixedNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.nonce[4:], nonce) + return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) +} + +func (f *fixedNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) { + copy(f.nonce[4:], nonce) + return f.aead.Open(out, f.nonce[:], plaintext, additionalData) +} + +// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [12]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} + +func aeadAESGCM12(key, fixedNonce []byte) cipher.AEAD { + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &fixedNonceAEAD{aead: aead} + copy(ret.nonce[:], fixedNonce) + return ret +} + +func aeadAESGCM13(key, fixedNonce []byte) cipher.AEAD { + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], fixedNonce) + return ret +} + +func aeadChaCha20Poly1305(key, fixedNonce []byte) cipher.AEAD { + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], fixedNonce) + return ret +} + +// ssl30MAC implements the SSLv3 MAC function, as defined in +// www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 5.2.3.1 +type ssl30MAC struct { + h hash.Hash + key []byte +} + +func (s ssl30MAC) Size() int { + return s.h.Size() +} + +var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36} + +var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c} + +// MAC does not offer constant timing guarantees for SSL v3.0, since it's deemed +// useless considering the similar, protocol-level POODLE vulnerability. +func (s ssl30MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { + padLength := 48 + if s.h.Size() == 20 { + padLength = 40 + } + + s.h.Reset() + s.h.Write(s.key) + s.h.Write(ssl30Pad1[:padLength]) + s.h.Write(seq) + s.h.Write(header[:1]) + s.h.Write(header[3:5]) + s.h.Write(data) + digestBuf = s.h.Sum(digestBuf[:0]) + + s.h.Reset() + s.h.Write(s.key) + s.h.Write(ssl30Pad2[:padLength]) + s.h.Write(digestBuf) + return s.h.Sum(digestBuf[:0]) +} + +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3. +type tls10MAC struct { + h hash.Hash +} + +func (s tls10MAC) Size() int { + return s.h.Size() +} + +// MAC is guaranteed to take constant time, as long as +// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into +// the MAC, but is only provided to make the timing profile constant. +func (s tls10MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { + s.h.Reset() + s.h.Write(seq) + s.h.Write(header) + s.h.Write(data) + res := s.h.Sum(digestBuf[:0]) + if extra != nil { + s.h.Write(extra) + } + return res +} + +func rsaKA(version uint16) keyAgreement { + return rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: false, + version: version, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: true, + version: version, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + for _, suite := range cipherSuites { + if suite.id == want { + return suite + } + } + return nil + } + } + return nil +} + +// A list of cipher suite IDs that are, or have been, implemented by this +// package. +// +// Taken from http://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 uint16 = 0xcca8 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 uint16 = 0xcca9 + + // TLS 1.3+ cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See + // https://tools.ietf.org/html/rfc7507. + TLS_FALLBACK_SCSV uint16 = 0x5600 +) diff --git a/vendor/github.com/marten-seemann/qtls/common.go b/vendor/github.com/marten-seemann/qtls/common.go new file mode 100644 index 00000000..76ce5e44 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/common.go @@ -0,0 +1,1215 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "container/list" + "crypto" + "crypto/rand" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "math/big" + "net" + "strings" + "sync" + "time" +) + +const ( + VersionSSL30 = 0x0300 + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxWarnAlertCount = 5 // maximum number of consecutive warning alerts + + minVersion = VersionTLS12 + maxVersion = VersionTLS13 +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeNextProtocol uint8 = 67 // Not IANA assigned +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +type Extension struct { + Type uint16 + Data []byte +} + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // Supported Groups in 1.3 nomenclature + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6 + extensionEMS uint16 = 23 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionPSKKeyExchangeModes uint16 = 45 + extensionCAs uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionNextProtoNeg uint16 = 13172 // not IANA assigned + extensionRenegotiationInfo uint16 = 0xff01 + extensionDelegatedCredential uint16 = 0xff02 // TODO(any) Get IANA assignment +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +// PSK Key Exchange Modes +// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.7 +const ( + pskDHEKeyExchange uint8 = 1 +) + +// CurveID is tls.CurveID +// TLS 1.3 refers to these as Groups, but this library implements only +// curve-based ones anyway. See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4. +type CurveID = tls.CurveID + +const ( + // Exported IDs + CurveP256 = tls.CurveP256 + CurveP384 = tls.CurveP384 + CurveP521 = tls.CurveP521 + X25519 = tls.X25519 +) + +// TLS 1.3 Key Share +// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5 +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Identity and Binder, as sent by the client +// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 + +type psk struct { + identity []byte + obfTicketAge uint32 + binder []byte +} + +// TLS Elliptic Curve Point Formats +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 // A certificate containing an RSA key + certTypeDSSSign = 2 // A certificate containing a DSA key + certTypeRSAFixedDH = 3 // A certificate containing a static DH key + certTypeDSSFixedDH = 4 // A certificate containing a static DH key + + // See RFC 4492 sections 3 and 5.5. + certTypeECDSASign = 64 // A certificate containing an ECDSA-capable public key, signed with ECDSA. + certTypeRSAFixedECDH = 65 // A certificate containing an ECDH-capable public key, signed with RSA. + certTypeECDSAFixedECDH = 66 // A certificate containing an ECDH-capable public key, signed with ECDSA. + + // Rest of these are reserved by the TLS spec +) + +// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1) +const ( + signaturePKCS1v15 uint8 = iota + 1 + signatureECDSA + signatureRSAPSS +) + +// supportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2 ClientHello and in a TLS 1.2 +// CertificateRequest. The two fields are merged to match with TLS 1.3. +// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. +var supportedSignatureAlgorithms = []SignatureScheme{ + PKCS1WithSHA256, + ECDSAWithP256AndSHA256, + PKCS1WithSHA384, + ECDSAWithP384AndSHA384, + PKCS1WithSHA512, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// supportedSignatureAlgorithms13 lists the advertised signature algorithms +// allowed for digital signatures. It includes TLS 1.2 + PSS. +var supportedSignatureAlgorithms13 = []SignatureScheme{ + PSSWithSHA256, + PKCS1WithSHA256, + ECDSAWithP256AndSHA256, + PSSWithSHA384, + PKCS1WithSHA384, + ECDSAWithP384AndSHA384, + PSSWithSHA512, + PKCS1WithSHA512, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// ConnectionState records basic TLS details about the connection. +type ConnectionState struct { + ConnectionID []byte // Random unique connection id + Version uint16 // TLS version used by the connection (e.g. VersionTLS12) + HandshakeComplete bool // TLS handshake is complete + DidResume bool // connection resumes a previous TLS connection + CipherSuite uint16 // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) + NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos) + NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only) + ServerName string // server name requested by client, if any (server side only) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates + SignedCertificateTimestamps [][]byte // SCTs from the server, if any + OCSPResponse []byte // stapled OCSP response from server, if any + DelegatedCredential []byte // Delegated credential sent by the server, if any + + // TLSUnique contains the "tls-unique" channel binding value (see RFC + // 5929, section 3). For resumed sessions this value will be nil + // because resumption does not include enough context (see + // https://mitls.org/pages/attacks/3SHAKE#channelbindings). This will + // change in future versions of Go once the TLS master-secret fix has + // been standardized and implemented. + TLSUnique []byte + + // HandshakeConfirmed is true once all data returned by Read + // (past and future) is guaranteed not to be replayed. + HandshakeConfirmed bool + + // Unique0RTTToken is a value that never repeats, and can be used + // to detect replay attacks against 0-RTT connections. + // Unique0RTTToken is only present if HandshakeConfirmed is false. + Unique0RTTToken []byte + + ClientHello []byte // ClientHello packet +} + +// The ClientAuthType is the tls.ClientAuthType +type ClientAuthType = tls.ClientAuthType + +const ( + NoClientCert = tls.NoClientCert + RequestClientCert = tls.RequestClientCert + RequireAnyClientCert = tls.RequireAnyClientCert + VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven + RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert +) + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // SSL/TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // MasterSecret generated by client on a full handshake + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + useEMS bool // State of extended master secret +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Only ticket-based resumption is supported, not SessionID-based +// resumption. +type ClientSessionCache interface { + // Get searches for a ClientSessionState associated with the given key. + // On return, ok is true if one was found. + Get(sessionKey string) (session *ClientSessionState, ok bool) + + // Put adds the ClientSessionState to the cache with the given key. + Put(sessionKey string, cs *ClientSessionState) +} + +// SignatureScheme is a tls.SignatureScheme +type SignatureScheme = tls.SignatureScheme + +const ( + PKCS1WithSHA1 = tls.PKCS1WithSHA1 + PKCS1WithSHA256 = tls.PKCS1WithSHA256 + PKCS1WithSHA384 = tls.PKCS1WithSHA384 + PKCS1WithSHA512 = tls.PKCS1WithSHA512 + + PSSWithSHA256 = tls.PSSWithSHA256 + PSSWithSHA384 = tls.PSSWithSHA384 + PSSWithSHA512 = tls.PSSWithSHA512 + + ECDSAWithP256AndSHA256 = tls.ECDSAWithP256AndSHA256 + ECDSAWithP384AndSHA384 = tls.ECDSAWithP384AndSHA384 + ECDSAWithP521AndSHA512 = tls.ECDSAWithP521AndSHA512 + + // Legacy signature and hash algorithms for TLS 1.2. + ECDSAWithSHA1 = tls.ECDSAWithSHA1 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide certificate selection in the GetCertificate callback. +type ClientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_RSA_WITH_RC4_128_SHA). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see + // http://tools.ietf.org/html/rfc4366#section-3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see + // http://tools.ietf.org/html/rfc4492#section-5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see + // http://tools.ietf.org/html/rfc4492#section-5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see + // https://tools.ietf.org/html/rfc7301#section-3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn net.Conn + + // Offered0RTTData is true if the client announced that it will send + // 0-RTT data. If the server Config.Accept0RTTData is true, and the + // client offered a session ticket valid for that purpose, it will + // be notified that the 0-RTT data is accepted and it will be made + // immediately available for Read. + Offered0RTTData bool + + // AcceptsDelegatedCredential is true if the client indicated willingness + // to negotiate the delegated credential extension. + AcceptsDelegatedCredential bool + + // The Fingerprint is an sequence of bytes unique to this Client Hello. + // It can be used to prevent or mitigate 0-RTT data replays as it's + // guaranteed that a replayed connection will have the same Fingerprint. + Fingerprint []byte +} + +// The CertificateRequestInfo is a tls.CertificateRequestInfo +type CertificateRequestInfo = tls.CertificateRequestInfo + +// RenegotiationSupport is a tls.RenegotiationSupport +type RenegotiationSupport = tls.RenegotiationSupport + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever = tls.RenegotiateNever + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to + // the other side of the connection. Server configurations must include + // at least one certificate or else set GetCertificate. Clients doing + // client-authentication may set either Certificates or + // GetClientCertificate. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // See Config.BuildNameToCertificate + // The nil value causes the first element of Certificates to be used + // for all connections. + NameToCertificate map[string]*Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // first element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // Uniquely for the fields in the returned Config, session ticket keys + // will be duplicated from the original Config if not set. + // Specifically, if SetSessionTicketKeys was called on the original + // config but not on the returned config then the ticket keys from the + // original config will be copied into the new config before use. + // Otherwise, if SessionTicketKey was set in the original config but + // not in the returned config then it will be copied into the returned + // config before use. If neither of those cases applies then the key + // material from the returned config will be used for session tickets. + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported, application level protocols. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the + // server's certificate chain and host name. + // If InsecureSkipVerify is true, TLS accepts any certificate + // presented by the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + InsecureSkipVerify bool + + // CipherSuites is a list of supported cipher suites to be used in + // TLS 1.0-1.2. If CipherSuites is nil, TLS uses a list of suites + // supported by the implementation. + CipherSuites []uint16 + + // PreferServerCipherSuites controls whether the server selects the + // client's most preferred ciphersuite, or the server's most preferred + // ciphersuite. If true then the server's preference, as expressed in + // the order of elements in CipherSuites, is used. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket + // (resumption) support. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session + // resumption. See RFC 5077. If zero, it will be filled with + // random data before the first server handshake. + // + // If multiple servers are terminating connections for the same host + // they should all have the same SessionTicketKey. If the + // SessionTicketKey leaks, previously recorded and future TLS + // connections using that key are compromised. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries for TLS + // session resumption. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum SSL/TLS version that is acceptable. + // If zero, then TLS 1.0 is taken as the minimum. + MinVersion uint16 + + // MaxVersion contains the maximum SSL/TLS version that is acceptable. + // If zero, then the maximum version supported by this package is used, + // which is currently TLS 1.2. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // If Max0RTTDataSize is not zero, the client will be allowed to use + // session tickets to send at most this number of bytes of 0-RTT data. + // 0-RTT data is subject to replay and has memory DoS implications. + // The server will later be able to refuse the 0-RTT data with + // Accept0RTTData, or wait for the client to prove that it's not + // replayed with Conn.ConfirmHandshake. + // + // It has no meaning on the client. + // + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3. + Max0RTTDataSize uint32 + + // Accept0RTTData makes the 0-RTT data received from the client + // immediately available to Read. 0-RTT data is subject to replay. + // Use Conn.ConfirmHandshake to wait until the data is known not + // to be replayed after reading it. + // + // It has no meaning on the client. + // + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3. + Accept0RTTData bool + + // SessionTicketSealer, if not nil, is used to wrap and unwrap + // session tickets, instead of SessionTicketKey. + SessionTicketSealer SessionTicketSealer + + // AcceptDelegatedCredential is true if the client is willing to negotiate + // the delegated credential extension. + // + // This value has no meaning for the server. + // + // See https://tools.ietf.org/html/draft-ietf-tls-subcerts-02. + AcceptDelegatedCredential bool + + // GetDelegatedCredential returns a DC and its private key for use in the + // delegated credential extension. The inputs to the callback are some + // information parsed from the ClientHello, as well as the protocol version + // selected by the server. This is necessary because the DC is bound to the + // protocol version in which it's used. The return value is the raw DC + // encoded in the wire format specified in + // https://tools.ietf.org/html/draft-ietf-tls-subcerts-02. If the return + // value is nil, then the server will not offer negotiate the extension. + // + // This value has no meaning for the client. + GetDelegatedCredential func(*ClientHelloInfo, uint16) ([]byte, crypto.PrivateKey, error) + + // GetExtensions, if not nil, is called before a message that allows + // sending of extensions is sent. + // Currently only implemented for the ClientHello message (for the client) + // and for the EncryptedExtensions message (for the server). + // Only valid for TLS 1.3. + GetExtensions func(handshakeMessageType uint8) []Extension + + // ReceivedExtensions, if not nil, is called when a message that allows the + // inclusion of extensions is received. + // It is called with an empty slice of extensions, if the message didn't + // contain any extensions. + // Currently only implemented for the ClientHello message (sent by the + // client) and for the EncryptedExtensions message (sent by the server). + // Only valid for TLS 1.3. + ReceivedExtensions func(handshakeMessageType uint8, exts []Extension) error + + serverInitOnce sync.Once // guards calling (*Config).serverInit + + // mutex protects sessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If the length + // is zero, SessionTicketsDisabled must be true. The first key is used + // for new tickets and any subsequent keys can be used to decrypt old + // tickets. + sessionTicketKeys []ticketKey + + // UseExtendedMasterSecret indicates whether or not the connection + // should use the extended master secret computation if available + UseExtendedMasterSecret bool + + // AlternativeRecordLayer is used by QUIC + AlternativeRecordLayer RecordLayer +} + +type RecordLayer interface { + SetReadKey(suite *CipherSuite, trafficSecret []byte) + SetWriteKey(suite *CipherSuite, trafficSecret []byte) + ReadHandshakeMessage() ([]byte, error) + WriteRecord([]byte) (int, error) +} + +// ticketKeyNameLen is the number of bytes of identifier that is prepended to +// an encrypted session ticket in order to identify the key used to encrypt it. +const ticketKeyNameLen = 16 + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + return key +} + +// Clone returns a shallow clone of c. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *Config) Clone() *Config { + // Running serverInit ensures that it's safe to read + // SessionTicketsDisabled. + c.serverInitOnce.Do(func() { c.serverInit(nil) }) + + var sessionTicketKeys []ticketKey + c.mutex.RLock() + sessionTicketKeys = c.sessionTicketKeys + c.mutex.RUnlock() + + return &Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + Accept0RTTData: c.Accept0RTTData, + Max0RTTDataSize: c.Max0RTTDataSize, + SessionTicketSealer: c.SessionTicketSealer, + AcceptDelegatedCredential: c.AcceptDelegatedCredential, + GetDelegatedCredential: c.GetDelegatedCredential, + GetExtensions: c.GetExtensions, + ReceivedExtensions: c.ReceivedExtensions, + sessionTicketKeys: sessionTicketKeys, + UseExtendedMasterSecret: c.UseExtendedMasterSecret, + } +} + +// serverInit is run under c.serverInitOnce to do initialization of c. If c was +// returned by a GetConfigForClient callback then the argument should be the +// Config that was passed to Server, otherwise it should be nil. +func (c *Config) serverInit(originalConfig *Config) { + if c.SessionTicketsDisabled || len(c.ticketKeys()) != 0 || c.SessionTicketSealer != nil { + return + } + + alreadySet := false + for _, b := range c.SessionTicketKey { + if b != 0 { + alreadySet = true + break + } + } + + if !alreadySet { + if originalConfig != nil { + copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:]) + } else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + c.SessionTicketsDisabled = true + return + } + } + + if originalConfig != nil { + originalConfig.mutex.RLock() + c.sessionTicketKeys = originalConfig.sessionTicketKeys + originalConfig.mutex.RUnlock() + } else { + c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)} + } +} + +func (c *Config) ticketKeys() []ticketKey { + c.mutex.RLock() + // c.sessionTicketKeys is constant once created. SetSessionTicketKeys + // will only update it by replacing it with a new value. + ret := c.sessionTicketKeys + c.mutex.RUnlock() + return ret +} + +// SetSessionTicketKeys updates the session ticket keys for a server. The first +// key will be used when creating new tickets, while all keys can be used for +// decrypting tickets. It is safe to call this function while the server is +// running in order to rotate the session ticket keys. The function will panic +// if keys is empty. +func (c *Config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *Config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *Config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func hasOverlappingCipherSuites(cs1, cs2 []uint16) bool { + for _, c1 := range cs1 { + for _, c2 := range cs2 { + if c1 == c2 { + return true + } + } + } + return false +} + +func (c *Config) cipherSuites() []uint16 { + s := c.CipherSuites + if s == nil { + s = defaultCipherSuites() + } else if c.maxVersion() >= VersionTLS13 { + // Ensure that TLS 1.3 suites are always present, but respect + // the application cipher suite preferences. + s13 := defaultTLS13CipherSuites() + if !hasOverlappingCipherSuites(s, s13) { + allSuites := make([]uint16, len(s13)+len(s)) + allSuites = append(allSuites, s13...) + s = append(allSuites, s...) + } + } + return s +} + +func (c *Config) minVersion() uint16 { + if c == nil || c.MinVersion == 0 { + return minVersion + } + return c.MinVersion +} + +func (c *Config) maxVersion() uint16 { + if c == nil || c.MaxVersion == 0 { + return maxVersion + } + return c.MaxVersion +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *Config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +// mutualVersion returns the protocol version to use given the advertised +// version of the peer using the legacy non-extension methods. +func (c *Config) mutualVersion(vers uint16) (uint16, bool) { + minVersion := c.minVersion() + maxVersion := c.maxVersion() + + // Version 1.3 and higher are not negotiated via this mechanism. + if maxVersion > VersionTLS12 { + maxVersion = VersionTLS12 + } + + if vers < minVersion { + return 0, false + } + if vers > maxVersion { + vers = maxVersion + } + return vers, true +} + +// pickVersion returns the protocol version to use given the advertised +// versions of the peer using the Supported Versions extension. +func (c *Config) pickVersion(peerSupportedVersions []uint16) (uint16, bool) { + supportedVersions := c.getSupportedVersions() + for _, supportedVersion := range supportedVersions { + for _, version := range peerSupportedVersions { + if version == supportedVersion { + return version, true + } + } + } + return 0, false +} + +// configSuppVersArray is the backing array of Config.getSupportedVersions +var configSuppVersArray = [...]uint16{VersionTLS13, VersionTLS12, VersionTLS11, VersionTLS10, VersionSSL30} + +// getSupportedVersions returns the protocol versions that are supported by the +// current configuration. +func (c *Config) getSupportedVersions() []uint16 { + minVersion := c.minVersion() + maxVersion := c.maxVersion() + // Sanity check to avoid advertising unsupported versions. + if minVersion < VersionSSL30 { + minVersion = VersionSSL30 + } + if maxVersion > VersionTLS13 { + maxVersion = VersionTLS13 + } + if maxVersion < minVersion { + return nil + } + return configSuppVersArray[VersionTLS13-maxVersion : VersionTLS13-minVersion+1] +} + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errors.New("tls: no certificates configured") + } + + if len(c.Certificates) == 1 || c.NameToCertificate == nil { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + name := strings.ToLower(clientHello.ServerName) + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + + // try replacing labels in the name with wildcards until we get a + // match. + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[candidate]; ok { + return cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +func (c *Config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + continue + } + if len(x509Cert.Subject.CommonName) > 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +// writeKeyLog logs client random and master secret if logging was enabled by +// setting c.KeyLogWriter. +func (c *Config) writeKeyLog(what string, clientRandom, masterSecret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("%s %x %x\n", what, clientRandom, masterSecret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a tls.Certificate +type Certificate = tls.Certificate + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) alert +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +// TODO(jsing): Make these available to both crypto/x509 and crypto/tls. +type dsaSignature struct { + R, S *big.Int +} + +type ecdsaSignature dsaSignature + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +var ( + once sync.Once + varDefaultCipherSuites []uint16 + varDefaultTLS13CipherSuites []uint16 +) + +func defaultCipherSuites() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuites +} + +func defaultTLS13CipherSuites() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultTLS13CipherSuites +} + +func initDefaultCipherSuites() { + var topCipherSuites, topTLS13CipherSuites []uint16 + // TODO: check for hardware support + // This used to be: if cipherhw.AESGCMSupport() { + // However, cipherhw is an internal package + if true { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topTLS13CipherSuites = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, + } + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topTLS13CipherSuites = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + } + + varDefaultTLS13CipherSuites = make([]uint16, 0, len(cipherSuites)) + varDefaultTLS13CipherSuites = append(varDefaultTLS13CipherSuites, topTLS13CipherSuites...) + varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) + varDefaultCipherSuites = append(varDefaultCipherSuites, topCipherSuites...) + +NextCipherSuite: + for _, suite := range cipherSuites { + if suite.flags&suiteDefaultOff != 0 { + continue + } + if suite.flags&suiteTLS13 != 0 { + for _, existing := range varDefaultTLS13CipherSuites { + if existing == suite.id { + continue NextCipherSuite + } + } + varDefaultTLS13CipherSuites = append(varDefaultTLS13CipherSuites, suite.id) + } else { + for _, existing := range varDefaultCipherSuites { + if existing == suite.id { + continue NextCipherSuite + } + } + varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) + } + } + varDefaultCipherSuites = append(varDefaultTLS13CipherSuites, varDefaultCipherSuites...) +} + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { + for _, s := range supportedSignatureAlgorithms { + if s == sigAlg { + return true + } + } + return false +} + +// signatureFromSignatureScheme maps a signature algorithm to the underlying +// signature method (without hash function). +func signatureFromSignatureScheme(signatureAlgorithm SignatureScheme) uint8 { + switch signatureAlgorithm { + case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: + return signaturePKCS1v15 + case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: + return signatureRSAPSS + case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: + return signatureECDSA + default: + return 0 + } +} + +// TODO(kk): Use variable length encoding? +func getUint24(b []byte) int { + n := int(b[2]) + n += int(b[1] << 8) + n += int(b[0] << 16) + return n +} + +func putUint24(b []byte, n int) { + b[0] = byte(n >> 16) + b[1] = byte(n >> 8) + b[2] = byte(n & 0xff) +} diff --git a/vendor/github.com/marten-seemann/qtls/conn.go b/vendor/github.com/marten-seemann/qtls/conn.go new file mode 100644 index 00000000..27761e50 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/conn.go @@ -0,0 +1,1766 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package qtls + +import ( + "bytes" + "crypto/cipher" + "crypto/subtle" + "crypto/x509" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + isClient bool + + phase handshakeStatus // protected by in.Mutex + // handshakeConfirmed is an atomic bool for phase == handshakeConfirmed + handshakeConfirmed int32 + // confirmMutex is held by any read operation before handshakeConfirmed + confirmMutex sync.Mutex + + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex + handshakeErr error // error resulting from handshake + connID []byte // Random connection id + clientHello []byte // ClientHello packet contents + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *Config // configuration passed to constructor + // handshakeComplete is true if the connection reached application data + // and it's equivalent to phase > handshakeRunning + handshakeComplete bool + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // Signed certificate timestamps from server + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // verifiedDc is set by a client who negotiates the use of a valid delegated + // credential. + verifiedDc *delegatedCredential + // serverName contains the server name indicated by the client, if any. + serverName string + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + // indicates wether extended MasterSecret extension is used (see RFC7627) + useEMS bool + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + clientProtocol string + clientProtocolFallback bool + + // ticketMaxEarlyData is the maximum bytes of 0-RTT application data + // that the client is allowed to send on the ticket it used. + ticketMaxEarlyData int64 + + // input/output + in, out halfConn // in.Mutex < out.Mutex + rawInput *block // raw input, right off the wire + input *block // application data waiting to be read + hand bytes.Buffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf []byte // a buffer of records waiting to be sent + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 + + // warnCount counts the number of consecutive warning alerts received + // by Conn.readRecord. Protected by in.Mutex. + warnCount int + + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + activeCall int32 + + // TLS 1.3 needs the server state until it reaches the Client Finished + hs *serverHandshakeState + + // earlyDataBytes is the number of bytes of early data received so + // far. Tracked to enforce max_early_data_size. + // We don't keep track of rejected 0-RTT data since there's no need + // to ever buffer it. in.Mutex. + earlyDataBytes int64 + + // binder is the value of the PSK binder that was validated to + // accept the 0-RTT data. Exposed as ConnectionState.Unique0RTTToken. + binder []byte + + tmp [16]byte +} + +type handshakeStatus int + +const ( + handshakeRunning handshakeStatus = iota + discardingEarlyData + readingEarlyData + waitingClientFinished + readingClientFinished + handshakeConfirmed +) + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac macFunction + seq [8]byte // 64-bit sequence number + bfree *block // list of free blocks + additionalData [13]byte // to avoid allocs; interface method args escape + + nextCipher interface{} // next encryption state + nextMac macFunction // next MAC algorithm + + // used to save allocating a new buffer for each MAC. + inDigestBuf, outDigestBuf []byte + + setKeyCallback func(suite *CipherSuite, trafficSecret []byte) + + traceErr func(error) +} + +func (hc *halfConn) setErrorLocked(err error) error { + hc.err = err + if hc.traceErr != nil { + hc.traceErr(err) + } + return err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() error { + if hc.nextCipher == nil { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + for i := range hc.seq { + hc.seq[i] = 0 + } + return nil +} + +func (hc *halfConn) setKey(version uint16, suite *cipherSuite, trafficSecret []byte) { + if hc.setKeyCallback != nil { + hc.setKeyCallback(&CipherSuite{*suite}, trafficSecret) + return + } + hc.version = version + hash := hashForSuite(suite) + key := hkdfExpandLabel(hash, trafficSecret, nil, "key", suite.keyLen) + iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", suite.ivLen) + hc.cipher = suite.aead(key, iv) + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 +func extractPadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + toRemove = int(paddingLen) + 1 + return +} + +// extractPaddingSSL30 is a replacement for extractPadding in the case that the +// protocol version is SSLv3. In this version, the contents of the padding +// are random and cannot be checked. +func extractPaddingSSL30(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := int(payload[len(payload)-1]) + 1 + if paddingLen > len(payload) { + return 0, 0 + } + + return paddingLen, 255 +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt checks and strips the mac and decrypts the data in b. Returns a +// success boolean, the number of bytes to skip from the start of the record in +// order to get the application payload, and an optional alert value. +func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) { + // pull out payload + payload := b.data[recordHeaderLen:] + + macSize := 0 + if hc.mac != nil { + macSize = hc.mac.Size() + } + + paddingGood := byte(255) + paddingLen := 0 + explicitIVLen := 0 + + // decrypt + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + explicitIVLen = c.explicitNonceLen() + if len(payload) < explicitIVLen { + return false, 0, alertBadRecordMAC + } + nonce := payload[:explicitIVLen] + payload = payload[explicitIVLen:] + + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + var additionalData []byte + if hc.version < VersionTLS13 { + copy(hc.additionalData[:], hc.seq[:]) + copy(hc.additionalData[8:], b.data[:3]) + n := len(payload) - c.Overhead() + hc.additionalData[11] = byte(n >> 8) + hc.additionalData[12] = byte(n) + additionalData = hc.additionalData[:] + } else { + if len(payload) > int((1<<14)+256) { + return false, 0, alertRecordOverflow + } + // Check AD header, see 5.2 of RFC8446 + additionalData = make([]byte, 5) + additionalData[0] = byte(recordTypeApplicationData) + binary.BigEndian.PutUint16(additionalData[1:], VersionTLS12) + binary.BigEndian.PutUint16(additionalData[3:], uint16(len(payload))) + } + var err error + payload, err = c.Open(payload[:0], nonce, payload, additionalData) + if err != nil { + return false, 0, alertBadRecordMAC + } + b.resize(recordHeaderLen + explicitIVLen + len(payload)) + case cbcMode: + blockSize := c.BlockSize() + if hc.version >= VersionTLS11 { + explicitIVLen = blockSize + } + + if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) { + return false, 0, alertBadRecordMAC + } + + if explicitIVLen > 0 { + c.SetIV(payload[:explicitIVLen]) + payload = payload[explicitIVLen:] + } + c.CryptBlocks(payload, payload) + if hc.version == VersionSSL30 { + paddingLen, paddingGood = extractPaddingSSL30(payload) + } else { + paddingLen, paddingGood = extractPadding(payload) + + // To protect against CBC padding oracles like Lucky13, the data + // past paddingLen (which is secret) is passed to the MAC + // function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write. + } + default: + panic("unknown cipher type") + } + } + + // check, strip mac + if hc.mac != nil { + if len(payload) < macSize { + return false, 0, alertBadRecordMAC + } + + // strip mac off payload, b.data + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } + b.data[3] = byte(n >> 8) + b.data[4] = byte(n) + remoteMAC := payload[n : n+macSize] + localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:]) + + if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { + return false, 0, alertBadRecordMAC + } + hc.inDigestBuf = localMAC + + b.resize(recordHeaderLen + explicitIVLen + n) + } + hc.incSeq() + + return true, recordHeaderLen + explicitIVLen, 0 +} + +// padToBlockSize calculates the needed padding block, if any, for a payload. +// On exit, prefix aliases payload and extends to the end of the last full +// block of payload. finalBlock is a fresh slice which contains the contents of +// any suffix of payload as well as the needed padding to make finalBlock a +// full block. +func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) { + overrun := len(payload) % blockSize + paddingLen := blockSize - overrun + prefix = payload[:len(payload)-overrun] + finalBlock = make([]byte, blockSize) + copy(finalBlock, payload[len(payload)-overrun:]) + for i := overrun; i < blockSize; i++ { + finalBlock[i] = byte(paddingLen - 1) + } + return +} + +// encrypt encrypts and macs the data in b. +func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { + // mac + if hc.mac != nil { + mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil) + + n := len(b.data) + b.resize(n + len(mac)) + copy(b.data[n:], mac) + hc.outDigestBuf = mac + } + + payload := b.data[recordHeaderLen:] + + // encrypt + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + // explicitIVLen is always 0 for TLS1.3 + payloadLen := len(b.data) - recordHeaderLen - explicitIVLen + payloadOffset := recordHeaderLen + explicitIVLen + nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + var additionalData []byte + if hc.version < VersionTLS13 { + // make room in a buffer for payload + MAC + b.resize(len(b.data) + c.Overhead()) + + payload = b.data[payloadOffset : payloadOffset+payloadLen] + copy(hc.additionalData[:], hc.seq[:]) + copy(hc.additionalData[8:], b.data[:3]) + binary.BigEndian.PutUint16(hc.additionalData[11:], uint16(payloadLen)) + additionalData = hc.additionalData[:] + } else { + // make room in a buffer for TLSCiphertext.encrypted_record: + // payload + MAC + extra data if needed + b.resize(len(b.data) + c.Overhead() + 1) + + payload = b.data[payloadOffset : payloadOffset+payloadLen+1] + // 1 byte of content type is appended to payload and encrypted + payload[len(payload)-1] = b.data[0] + + // opaque_type + b.data[0] = byte(recordTypeApplicationData) + + // Add AD header, see 5.2 of RFC8446 + additionalData = make([]byte, 5) + additionalData[0] = b.data[0] + binary.BigEndian.PutUint16(additionalData[1:], VersionTLS12) + binary.BigEndian.PutUint16(additionalData[3:], uint16(len(payload)+c.Overhead())) + } + c.Seal(payload[:0], nonce, payload, additionalData) + case cbcMode: + blockSize := c.BlockSize() + if explicitIVLen > 0 { + c.SetIV(payload[:explicitIVLen]) + payload = payload[explicitIVLen:] + } + prefix, finalBlock := padToBlockSize(payload, blockSize) + b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock)) + c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix) + c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock) + default: + panic("unknown cipher type") + } + } + + // update length to include MAC and any block padding needed. + n := len(b.data) - recordHeaderLen + b.data[3] = byte(n >> 8) + b.data[4] = byte(n) + hc.incSeq() + + return true, 0 +} + +// A block is a simple data buffer. +type block struct { + data []byte + off int // index for Read + link *block +} + +// resize resizes block to be n bytes, growing if necessary. +func (b *block) resize(n int) { + if n > cap(b.data) { + b.reserve(n) + } + b.data = b.data[0:n] +} + +// reserve makes sure that block contains a capacity of at least n bytes. +func (b *block) reserve(n int) { + if cap(b.data) >= n { + return + } + m := cap(b.data) + if m == 0 { + m = 1024 + } + for m < n { + m *= 2 + } + data := make([]byte, len(b.data), m) + copy(data, b.data) + b.data = data +} + +// readFromUntil reads from r into b until b contains at least n bytes +// or else returns an error. +func (b *block) readFromUntil(r io.Reader, n int) error { + // quick case + if len(b.data) >= n { + return nil + } + + // read until have enough. + b.reserve(n) + for { + m, err := r.Read(b.data[len(b.data):cap(b.data)]) + b.data = b.data[0 : len(b.data)+m] + if len(b.data) >= n { + // TODO(bradfitz,agl): slightly suspicious + // that we're throwing away r.Read's err here. + break + } + if err != nil { + return err + } + } + return nil +} + +func (b *block) Read(p []byte) (n int, err error) { + n = copy(p, b.data[b.off:]) + b.off += n + if b.off >= len(b.data) { + err = io.EOF + } + return +} + +// newBlock allocates a new block, from hc's free list if possible. +func (hc *halfConn) newBlock() *block { + b := hc.bfree + if b == nil { + return new(block) + } + hc.bfree = b.link + b.link = nil + b.resize(0) + return b +} + +// freeBlock returns a block to hc's free list. +// The protocol is such that each side only has a block or two on +// its free list at a time, so there's no need to worry about +// trimming the list, etc. +func (hc *halfConn) freeBlock(b *block) { + b.link = hc.bfree + hc.bfree = b +} + +// splitBlock splits a block after the first n bytes, +// returning a block with those n bytes and a +// block with the remainder. the latter may be nil. +func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) { + if len(b.data) <= n { + return b, nil + } + bb := hc.newBlock() + bb.resize(len(b.data) - n) + copy(bb.data, b.data[n:]) + b.data = b.data[0:n] + return b, bb +} + +// RecordHeaderError results when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) { + err.Msg = msg + copy(err.RecordHeader[:], c.rawInput.data) + return err +} + +// readRecord reads the next TLS record from the connection +// and updates the record layer state. +// c.in.Mutex <= L; c.input == nil. +// c.input can still be nil after a call, retry if so. +func (c *Conn) readRecord(want recordType) error { + // Caller must be in sync with connection: + // handshake data if handshake not yet completed, + // else application data. + switch want { + default: + c.sendAlert(alertInternalError) + return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) + case recordTypeHandshake, recordTypeChangeCipherSpec: + if c.phase != handshakeRunning && c.phase != readingClientFinished { + c.sendAlert(alertInternalError) + return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) + } + case recordTypeApplicationData: + if c.phase == handshakeRunning || c.phase == readingClientFinished { + c.sendAlert(alertInternalError) + return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) + } + } + +Again: + if c.rawInput == nil { + c.rawInput = c.in.newBlock() + } + b := c.rawInput + + // Read header, payload. + if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC suggests that EOF without an alertCloseNotify is + // an error, but popular web sites seem to do this, + // so we can't make it an error. + // if err == io.EOF { + // err = io.ErrUnexpectedEOF + // } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + typ := recordType(b.data[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + if want == recordTypeHandshake && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received")) + } + + vers := uint16(b.data[1])<<8 | uint16(b.data[2]) + n := int(b.data[3])<<8 | int(b.data[4]) + if n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(msg)) + } + if !c.haveVers { + // First message, be extra suspicious: this might not be a TLS + // client. Bail out before reading a full 'body', if possible. + // The current max version is 3.3 so if the version is >= 16.0, + // it's probably not real. + if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake")) + } + } + if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + + // Process message. + b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) + + // TLS 1.3 middlebox compatibility: skip over unencrypted CCS. + if c.vers >= VersionTLS13 && typ == recordTypeChangeCipherSpec && c.phase != handshakeConfirmed { + if len(b.data) != 6 || b.data[5] != 1 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.in.freeBlock(b) + return c.in.err + } + + peekedAlert := peekAlert(b) // peek at a possible alert before decryption + ok, off, alertValue := c.in.decrypt(b) + switch { + case !ok && c.phase == discardingEarlyData: + // If the client said that it's sending early data and we did not + // accept it, we are expected to fail decryption. + c.in.freeBlock(b) + return nil + case ok && c.phase == discardingEarlyData: + c.phase = waitingClientFinished + case !ok: + c.in.traceErr, c.out.traceErr = nil, nil // not that interesting + c.in.freeBlock(b) + err := c.sendAlert(alertValue) + // If decryption failed because the message is an unencrypted + // alert, return a more meaningful error message + if alertValue == alertBadRecordMAC && peekedAlert != nil { + err = peekedAlert + } + return c.in.setErrorLocked(err) + } + b.off = off + data := b.data[b.off:] + if (c.vers < VersionTLS13 && len(data) > maxPlaintext) || len(data) > maxPlaintext+1 { + c.in.freeBlock(b) + return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) + } + + // After checking the plaintext length, remove 1.3 padding and + // extract the real content type. + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-5.4. + if c.vers >= VersionTLS13 { + i := len(data) - 1 + for i >= 0 { + if data[i] != 0 { + break + } + i-- + } + if i < 0 { + c.in.freeBlock(b) + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + typ = recordType(data[i]) + data = data[:i] + b.resize(b.off + i) // shrinks, guaranteed not to reallocate + } + + if typ != recordTypeAlert && len(data) > 0 { + // this is a valid non-alert message: reset the count of alerts + c.warnCount = 0 + } + + switch typ { + default: + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + if alert(data[1]) == alertCloseNotify { + c.in.setErrorLocked(io.EOF) + break + } + switch data[0] { + case alertLevelWarning: + // drop on the floor + c.in.freeBlock(b) + + c.warnCount++ + if c.warnCount > maxWarnAlertCount { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many warn alerts")) + } + + goto Again + case alertLevelError: + c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if typ != want || len(data) != 1 || data[0] != 1 || c.vers >= VersionTLS13 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + // Handshake messages are not allowed to fragment across the CCS + if c.hand.Len() > 0 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + // Handshake messages are not allowed to fragment across the CCS + if c.hand.Len() > 0 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + err := c.in.changeCipherSpec() + if err != nil { + c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if typ != want || c.phase == waitingClientFinished { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + if c.phase == readingEarlyData { + c.earlyDataBytes += int64(len(b.data) - b.off) + if c.earlyDataBytes > c.ticketMaxEarlyData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + } + c.input = b + b = nil + + case recordTypeHandshake: + // TODO(rsc): Should at least pick off connection close. + // If early data was being read, a Finished message is expected + // instead of (early) application data. Other post-handshake + // messages include HelloRequest and NewSessionTicket. + if typ != want && want != recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.hand.Write(data) + } + + if b != nil { + c.in.freeBlock(b) + } + return c.in.err +} + +// peekAlert looks at a message to spot an unencrypted alert. It must be +// called before decryption to avoid a side channel, and its result must +// only be used if decryption fails, to avoid false positives. +func peekAlert(b *block) error { + if len(b.data) < 7 { + return nil + } + if recordType(b.data[0]) != recordTypeAlert { + return nil + } + return &net.OpError{Op: "remote error", Err: alert(b.data[6])} +} + +// sendAlert sends a TLS alert message. +// c.out.Mutex <= L. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) +} + +// sendAlert sends a TLS alert message. +// L < c.out.Mutex. +func (c *Conn) sendAlert(err alert) error { + if c.config.AlternativeRecordLayer != nil { + return nil + } + c.out.Lock() + defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +// +// c.out.Mutex <= L. +func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + macSize := 0 + if c.out.mac != nil { + macSize = c.out.mac.Size() + } + + payloadBytes := tcpMSSEstimate - recordHeaderLen - explicitIVLen + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= macSize + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + if c.vers >= VersionTLS13 { + payloadBytes -= 1 // ContentType + } + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= macSize + default: + panic("unknown cipher type") + } + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +// c.out.Mutex <= L. +func (c *Conn) write(data []byte) (int, error) { + if c.buffering { + c.sendBuf = append(c.sendBuf, data...) + return len(data), nil + } + + n, err := c.conn.Write(data) + c.bytesSent += int64(n) + return n, err +} + +func (c *Conn) flush() (int, error) { + if len(c.sendBuf) == 0 { + return 0, nil + } + + n, err := c.conn.Write(c.sendBuf) + c.bytesSent += int64(n) + c.sendBuf = nil + c.buffering = false + return n, err +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +// c.out.Mutex <= L. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + b := c.out.newBlock() + defer c.out.freeBlock(b) + + var n int + for len(data) > 0 { + explicitIVLen := 0 + explicitIVIsSeq := false + + var cbc cbcMode + if c.out.version >= VersionTLS11 { + var ok bool + if cbc, ok = c.out.cipher.(cbcMode); ok { + explicitIVLen = cbc.BlockSize() + } + } + if explicitIVLen == 0 { + if c, ok := c.out.cipher.(aead); ok { + explicitIVLen = c.explicitNonceLen() + + // The AES-GCM construction in TLS has an + // explicit nonce so that the nonce can be + // random. However, the nonce is only 8 bytes + // which is too small for a secure, random + // nonce. Therefore we use the sequence number + // as the nonce. + explicitIVIsSeq = explicitIVLen > 0 + } + } + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ, explicitIVLen); m > maxPayload { + m = maxPayload + } + b.resize(recordHeaderLen + explicitIVLen + m) + b.data[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } + if c.vers >= VersionTLS13 { + // TLS 1.3 froze the record layer version at { 3, 1 }. + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-5.1. + // But for draft 22, this was changed to { 3, 3 }. + vers = VersionTLS12 + } + b.data[1] = byte(vers >> 8) + b.data[2] = byte(vers) + b.data[3] = byte(m >> 8) + b.data[4] = byte(m) + if explicitIVLen > 0 { + explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] + if explicitIVIsSeq { + copy(explicitIV, c.out.seq[:]) + } else { + if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil { + return n, err + } + } + } + copy(b.data[recordHeaderLen+explicitIVLen:], data) + c.out.encrypt(b, explicitIVLen) + if _, err := c.write(b.data); err != nil { + return n, err + } + n += m + data = data[m:] + } + + if typ == recordTypeChangeCipherSpec && c.vers < VersionTLS13 { + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) + } + } + + return n, nil +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +// L < c.out.Mutex. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + if c.config.AlternativeRecordLayer != nil { + if typ == recordTypeChangeCipherSpec { + return len(data), nil + } + return c.config.AlternativeRecordLayer.WriteRecord(data) + } + + c.out.Lock() + defer c.out.Unlock() + + return c.writeRecordLocked(typ, data) +} + +// readHandshake reads the next handshake message from +// the record layer. +// c.in.Mutex < L; c.out.Mutex < L. +func (c *Conn) readHandshake() (interface{}, error) { + var data []byte + if c.config.AlternativeRecordLayer != nil { + var err error + data, err = c.config.AlternativeRecordLayer.ReadHandshakeMessage() + if err != nil { + return nil, err + } + } else { + for c.hand.Len() < 4 { + if err := c.in.err; err != nil { + return nil, err + } + if err := c.readRecord(recordTypeHandshake); err != nil { + return nil, err + } + } + + data = c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) + } + for c.hand.Len() < 4+n { + if err := c.in.err; err != nil { + return nil, err + } + if err := c.readRecord(recordTypeHandshake); err != nil { + return nil, err + } + } + data = c.hand.Next(4 + n) + } + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeNewSessionTicket: + if c.vers >= VersionTLS13 { + m = new(newSessionTicketMsg13) + } else { + m = new(newSessionTicketMsg) + } + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeCertificate: + if c.vers >= VersionTLS13 { + m = new(certificateMsg13) + } else { + m = new(certificateMsg) + } + case typeCertificateRequest: + if c.vers >= VersionTLS13 { + m = new(certificateRequestMsg13) + } else { + m = &certificateRequestMsg{ + hasSignatureAndHash: c.vers >= VersionTLS12, + } + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAndHash: c.vers >= VersionTLS12, + } + case typeNextProtocol: + m = new(nextProtoMsg) + case typeFinished: + m = new(finishedMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if unmarshalAlert := m.unmarshal(data); unmarshalAlert != alertSuccess { + return nil, c.in.setErrorLocked(c.sendAlert(unmarshalAlert)) + } + return m, nil +} + +var ( + errClosed = errors.New("tls: use of closed connection") + errShutdown = errors.New("tls: protocol is shutdown") +) + +// Write writes data to the connection. +func (c *Conn) Write(b []byte) (int, error) { + // interlock with Close below + for { + x := atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return 0, errClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + defer atomic.AddInt32(&c.activeCall, -2) + break + } + } + + if err := c.Handshake(); err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if err := c.out.err; err != nil { + return 0, err + } + + if !c.handshakeComplete { + return 0, alertInternalError + } + + if c.closeNotifySent { + return 0, errShutdown + } + + // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // http://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // http://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && c.vers <= VersionTLS10 { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) +} + +// Process Handshake messages after the handshake has completed. +// c.in.Mutex <= L +func (c *Conn) handlePostHandshake() error { + msg, err := c.readHandshake() + if err != nil { + return err + } + + switch hm := msg.(type) { + case *helloRequestMsg: + return c.handleRenegotiation(hm) + case *newSessionTicketMsg13: + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return alertUnexpectedMessage + } + return nil // TODO implement session tickets + default: + c.sendAlert(alertUnexpectedMessage) + return alertUnexpectedMessage + } +} + +// handleRenegotiation processes a HelloRequest handshake message. +// c.in.Mutex <= L +func (c *Conn) handleRenegotiation(*helloRequestMsg) error { + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + if c.vers >= VersionTLS13 { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + c.phase = handshakeRunning + c.handshakeComplete = false + if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +func (c *Conn) setAlternativeRecordLayer() { + if c.config.AlternativeRecordLayer != nil { + c.in.setKeyCallback = c.config.AlternativeRecordLayer.SetReadKey + c.out.setKeyCallback = c.config.AlternativeRecordLayer.SetWriteKey + } +} + +// ConfirmHandshake waits for the handshake to reach a point at which +// the connection is certainly not replayed. That is, after receiving +// the Client Finished. +// +// If ConfirmHandshake returns an error and until ConfirmHandshake +// returns, the 0-RTT data should not be trusted not to be replayed. +// +// This is only meaningful in TLS 1.3 when Accept0RTTData is true and the +// client sent valid 0-RTT data. In any other case it's equivalent to +// calling Handshake. +func (c *Conn) ConfirmHandshake() error { + if c.isClient { + panic("ConfirmHandshake should only be called for servers") + } + + if err := c.Handshake(); err != nil { + return err + } + + if c.vers < VersionTLS13 { + return nil + } + + c.confirmMutex.Lock() + if atomic.LoadInt32(&c.handshakeConfirmed) == 1 { // c.phase == handshakeConfirmed + c.confirmMutex.Unlock() + return nil + } else { + defer func() { + // If we transitioned to handshakeConfirmed we already released the lock, + // otherwise do it here. + if c.phase != handshakeConfirmed { + c.confirmMutex.Unlock() + } + }() + } + + c.in.Lock() + defer c.in.Unlock() + + var input *block + // Try to read all data (if phase==readingEarlyData) or extract the + // remaining data from the previous read that could not fit in the read + // buffer (if c.input != nil). + if c.phase == readingEarlyData || c.input != nil { + buf := &bytes.Buffer{} + if _, err := buf.ReadFrom(earlyDataReader{c}); err != nil { + c.in.setErrorLocked(err) + return err + } + input = &block{data: buf.Bytes()} + } + + // At this point, earlyDataReader has read all early data and received + // the end_of_early_data signal. Expect a Finished message. + // Locks held so far: c.confirmMutex, c.in + // not confirmed implies c.phase == discardingEarlyData || c.phase == waitingClientFinished + for c.phase != handshakeConfirmed { + if err := c.hs.readClientFinished13(true); err != nil { + c.in.setErrorLocked(err) + return err + } + } + + if c.phase != handshakeConfirmed { + panic("should have reached handshakeConfirmed state") + } + if c.input != nil { + panic("should not have read past the Client Finished") + } + + c.input = input + + return nil +} + +// earlyDataReader wraps a Conn and reads only early data, both buffered +// and still on the wire. +type earlyDataReader struct { + c *Conn +} + +// c.in.Mutex <= L +func (r earlyDataReader) Read(b []byte) (n int, err error) { + c := r.c + + if c.phase == handshakeConfirmed { + // c.input might not be early data + panic("earlyDataReader called at handshakeConfirmed") + } + + for c.input == nil && c.in.err == nil && c.phase == readingEarlyData { + if err := c.readRecord(recordTypeApplicationData); err != nil { + return 0, err + } + if c.hand.Len() > 0 { + if err := c.handleEndOfEarlyData(); err != nil { + return 0, err + } + } + } + if err := c.in.err; err != nil { + return 0, err + } + + if c.input != nil { + n, err = c.input.Read(b) + if err == io.EOF { + err = nil + c.in.freeBlock(c.input) + c.input = nil + } + } + + // Following early application data, an end_of_early_data is expected. + if err == nil && c.phase != readingEarlyData && c.input == nil { + err = io.EOF + } + return +} + +// Read can be made to time out and return a net.Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetReadDeadline. +func (c *Conn) Read(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return + } + + c.confirmMutex.Lock() + if atomic.LoadInt32(&c.handshakeConfirmed) == 1 { // c.phase == handshakeConfirmed + c.confirmMutex.Unlock() + } else { + defer func() { + // If we transitioned to handshakeConfirmed we already released the lock, + // otherwise do it here. + if c.phase != handshakeConfirmed { + c.confirmMutex.Unlock() + } + }() + } + + c.in.Lock() + defer c.in.Unlock() + + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. So this loop ignores a limited number of empty records. + const maxConsecutiveEmptyRecords = 100 + for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ { + for c.input == nil && c.in.err == nil { + if err := c.readRecord(recordTypeApplicationData); err != nil { + // Soft error, like EAGAIN + return 0, err + } + if c.hand.Len() > 0 { + if c.phase == readingEarlyData || c.phase == waitingClientFinished { + if c.phase == readingEarlyData { + if err := c.handleEndOfEarlyData(); err != nil { + return 0, err + } + } + // Server has received all early data, confirm + // by reading the Client Finished message. + if err := c.hs.readClientFinished13(true); err != nil { + c.in.setErrorLocked(err) + return 0, err + } + continue + } + if err := c.handlePostHandshake(); err != nil { + return 0, err + } + } + } + if err := c.in.err; err != nil { + return 0, err + } + + n, err = c.input.Read(b) + if err == io.EOF { + err = nil + c.in.freeBlock(c.input) + c.input = nil + } + + // If a close-notify alert is waiting, read it so that + // we can return (n, EOF) instead of (n, nil), to signal + // to the HTTP response reading goroutine that the + // connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would + // otherwise not observe the EOF until its next read, + // by which time a client goroutine might have already + // tried to reuse the HTTP connection for a new + // request. + // See https://codereview.appspot.com/76400046 + // and https://golang.org/issue/3514 + if ri := c.rawInput; ri != nil && + n != 0 && err == nil && + c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert { + if recErr := c.readRecord(recordTypeApplicationData); recErr != nil { + err = recErr // will be io.EOF on closeNotify + } + } + + if n != 0 || err != nil { + return n, err + } + } + + return 0, io.ErrNoProgress +} + +// Close closes the connection. +func (c *Conn) Close() error { + // Interlock with Conn.Write above. + var x int32 + for { + x = atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return errClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + + var alertErr error + + c.handshakeMutex.Lock() + if c.handshakeComplete { + alertErr = c.closeNotify() + } + c.handshakeMutex.Unlock() + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.handshakeComplete { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + c.out.Lock() + defer c.out.Unlock() + + if !c.closeNotifySent { + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// Most uses of this package need not call Handshake +// explicitly: the first Read or Write will call it automatically. +// +// In TLS 1.3 Handshake returns after the client and server first flights, +// without waiting for the Client Finished. +func (c *Conn) Handshake() error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete { + return nil + } + + c.in.Lock() + defer c.in.Unlock() + + // The handshake cannot have completed when handshakeMutex was unlocked + // because this goroutine set handshakeCond. + if c.handshakeErr != nil || c.handshakeComplete { + panic("handshake should not have been able to complete after handshakeCond was set") + } + + c.connID = make([]byte, 8) + if _, err := io.ReadFull(c.config.rand(), c.connID); err != nil { + return err + } + + if c.isClient { + c.handshakeErr = c.clientHandshake() + } else { + c.handshakeErr = c.serverHandshake() + } + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the hadshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + if c.handshakeErr == nil && !c.handshakeComplete { + panic("handshake should have had a result.") + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + var state ConnectionState + state.HandshakeComplete = c.handshakeComplete + state.ServerName = c.serverName + + if c.handshakeComplete { + state.ConnectionID = c.connID + state.ClientHello = c.clientHello + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if c.verifiedDc != nil { + state.DelegatedCredential = c.verifiedDc.raw + } + state.HandshakeConfirmed = atomic.LoadInt32(&c.handshakeConfirmed) == 1 + if !state.HandshakeConfirmed { + state.Unique0RTTToken = c.binder + } + if !c.didResume { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + } + + return state +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.handshakeComplete { + return errors.New("tls: handshake has not yet been performed") + } + if len(c.verifiedChains) == 0 { + return errors.New("tls: handshake did not verify certificate chain") + } + return c.peerCertificates[0].VerifyHostname(host) +} diff --git a/vendor/github.com/marten-seemann/qtls/handshake_client.go b/vendor/github.com/marten-seemann/qtls/handshake_client.go new file mode 100644 index 00000000..b80f2554 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/handshake_client.go @@ -0,0 +1,1006 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync/atomic" +) + +type clientHandshakeState struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + masterSecret []byte + session *ClientSessionState + + // TLS 1.0-1.2 fields + finishedHash finishedHash + + // TLS 1.3 fields + keySchedule *keySchedule13 + privateKey []byte +} + +func makeClientHello(config *Config) (*clientHelloMsg, error) { + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + nextProtosLength := 0 + for _, proto := range config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return nil, errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + + if nextProtosLength > 0xffff { + return nil, errors.New("tls: NextProtos values too large") + } + + hello := &clientHelloMsg{ + vers: config.maxVersion(), + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + ocspStapling: true, + scts: true, + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + nextProtoNeg: len(config.NextProtos) > 0, + secureRenegotiationSupported: true, + delegatedCredential: config.AcceptDelegatedCredential, + alpnProtocols: config.NextProtos, + extendedMSSupported: config.UseExtendedMasterSecret, + } + possibleCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) + +NextCipherSuite: + for _, suiteId := range possibleCipherSuites { + for _, suite := range cipherSuites { + if suite.id != suiteId { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + continue NextCipherSuite + } + // Don't advertise TLS 1.3-only cipher suites unless + // we're attempting TLS 1.3. + if hello.vers < VersionTLS13 && suite.flags&suiteTLS13 != 0 { + continue NextCipherSuite + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + continue NextCipherSuite + } + } + + _, err := io.ReadFull(config.rand(), hello.random) + if err != nil { + return nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + if hello.vers >= VersionTLS12 { + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + if hello.vers >= VersionTLS13 { + // Version preference is indicated via "supported_extensions", + // set legacy_version to TLS 1.2 for backwards compatibility. + hello.vers = VersionTLS12 + hello.supportedVersions = config.getSupportedVersions() + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms13 + hello.supportedSignatureAlgorithmsCert = supportedSigAlgorithmsCert(supportedSignatureAlgorithms13) + if config.GetExtensions != nil { + hello.additionalExtensions = config.GetExtensions(typeClientHello) + } + } + + return hello, nil +} + +// c.out.Mutex <= L; c.handshakeMutex <= L. +func (c *Conn) clientHandshake() error { + if c.config == nil { + c.config = defaultConfig() + } + c.setAlternativeRecordLayer() + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, err := makeClientHello(c.config) + if err != nil { + return err + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + + var session *ClientSessionState + var cacheKey string + sessionCache := c.config.ClientSessionCache + // TLS 1.3 has no session resumption based on session tickets. + if c.config.SessionTicketsDisabled || c.config.maxVersion() >= VersionTLS13 { + sessionCache = nil + } + + if sessionCache != nil { + hello.ticketSupported = true + } + + // Session resumption is not allowed if renegotiating because + // renegotiation is primarily used to allow a client to send a client + // certificate, which would be skipped if session resumption occurred. + if sessionCache != nil && c.handshakes == 0 { + // Try to resume a previously negotiated TLS session, if + // available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + candidateSession, ok := sessionCache.Get(cacheKey) + if ok { + // Check that the ciphersuite/version used for the + // previous session are still valid. + cipherSuiteOk := false + for _, id := range hello.cipherSuites { + if id == candidateSession.cipherSuite { + cipherSuiteOk = true + break + } + } + + versOk := candidateSession.vers >= c.config.minVersion() && + candidateSession.vers <= c.config.maxVersion() + if versOk && cipherSuiteOk { + session = candidateSession + } + } + } + + if session != nil { + hello.sessionTicket = session.sessionTicket + // A random session ID is used to detect when the + // server accepted the ticket and is resuming a session + // (see RFC 5077). + hello.sessionId = make([]byte, 16) + if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil { + return errors.New("tls: short read from Rand: " + err.Error()) + } + } + + hs := &clientHandshakeState{ + c: c, + hello: hello, + session: session, + } + + var clientKS keyShare + if c.config.maxVersion() >= VersionTLS13 { + // Create one keyshare for the first default curve. If it is not + // appropriate, the server should raise a HRR. + defaultGroup := c.config.curvePreferences()[0] + hs.privateKey, clientKS, err = c.config.generateKeyShare(defaultGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hello.keyShares = []keyShare{clientKS} + // middlebox compatibility mode, provide a non-empty session ID + hello.sessionId = make([]byte, 16) + if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil { + return errors.New("tls: short read from Rand: " + err.Error()) + } + } + + if err = hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one + if sessionCache != nil && hs.session != nil && session != hs.session && c.vers < VersionTLS13 { + sessionCache.Put(cacheKey, hs.session) + } + + return nil +} + +// Does the handshake, either a full one or resumes old session. +// Requires hs.c, hs.hello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + c := hs.c + + // send ClientHello + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + var ok bool + if hs.serverHello, ok = msg.(*serverHelloMsg); !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(hs.serverHello, msg) + } + + if err = hs.pickTLSVersion(); err != nil { + return err + } + + if err = hs.pickCipherSuite(); err != nil { + return err + } + + var isResume bool + if c.vers >= VersionTLS13 { + hs.keySchedule = newKeySchedule13(hs.suite, c.config, hs.hello.random) + hs.keySchedule.write(hs.hello.marshal()) + hs.keySchedule.write(hs.serverHello.marshal()) + } else { + isResume, err = hs.processServerHello() + if err != nil { + return err + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + + // No signatures of the handshake are needed in a resumption. + // Otherwise, in a full handshake, if we don't have any certificates + // configured then we will never send a CertificateVerify message and + // thus no signatures are needed in that case either. + if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { + hs.finishedHash.discardHandshakeBuffer() + } + + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + } + + c.buffering = true + if c.vers >= VersionTLS13 { + if err := hs.doTLS13Handshake(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else if isResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = false + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = true + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + } + + c.didResume = isResume + c.phase = handshakeConfirmed + atomic.StoreInt32(&c.handshakeConfirmed, 1) + c.handshakeComplete = true + + return nil +} + +func (hs *clientHandshakeState) pickTLSVersion() error { + vers, ok := hs.c.config.pickVersion([]uint16{hs.serverHello.vers}) + if !ok || vers < VersionTLS10 { + // TLS 1.0 is the minimum version supported as a client. + hs.c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", hs.serverHello.vers) + } + + hs.c.vers = vers + hs.c.haveVers = true + + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + // Check that the chosen cipher suite matches the protocol version. + if hs.c.vers >= VersionTLS13 && hs.suite.flags&suiteTLS13 == 0 || + hs.c.vers < VersionTLS13 && hs.suite.flags&suiteTLS13 != 0 { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an inappropriate cipher suite") + } + + hs.c.cipherSuite = hs.suite.id + return nil +} + +// processCertsFromServer takes a chain of server certificates from a +// Certificate message and verifies them. +func (hs *clientHandshakeState) processCertsFromServer(certificates [][]byte) error { + c := hs.c + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + var err error + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + switch certs[0].PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) + } + + c.peerCertificates = certs + return nil +} + +// processDelegatedCredentialFromServer unmarshals the delegated credential +// offered by the server (if present) and validates it using the peer +// certificate and the signature scheme (`scheme`) indicated by the server in +// the "signature_scheme" extension. +func (hs *clientHandshakeState) processDelegatedCredentialFromServer(serialized []byte, scheme SignatureScheme) error { + c := hs.c + + var dc *delegatedCredential + var err error + if serialized != nil { + // Assert that the DC extension was indicated by the client. + if !hs.hello.delegatedCredential { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: got delegated credential extension without indication") + } + + // Parse the delegated credential. + dc, err = unmarshalDelegatedCredential(serialized) + if err != nil { + c.sendAlert(alertDecodeError) + return fmt.Errorf("tls: delegated credential: %s", err) + } + } + + if dc != nil && !c.config.InsecureSkipVerify { + if v, err := dc.validate(c.peerCertificates[0], c.config.time()); err != nil { + c.sendAlert(alertIllegalParameter) + return fmt.Errorf("delegated credential: %s", err) + } else if !v { + c.sendAlert(alertIllegalParameter) + return errors.New("delegated credential: signature invalid") + } else if dc.cred.expectedVersion != hs.c.vers { + c.sendAlert(alertIllegalParameter) + return errors.New("delegated credential: protocol version mismatch") + } else if dc.cred.expectedCertVerifyAlgorithm != scheme { + c.sendAlert(alertIllegalParameter) + return errors.New("delegated credential: signature scheme mismatch") + } + } + + c.verifiedDc = dc + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := hs.processCertsFromServer(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + + cs, ok := msg.(*certificateStatusMsg) + if ok { + // RFC4366 on Certificate Status Request: + // The server MAY return a "certificate_status" message. + + if !hs.serverHello.ocspStapling { + // If a server returns a "CertificateStatus" message, then the + // server MUST have included an extension of type "status_request" + // with empty "extension_data" in the extended server hello. + + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received unexpected CertificateStatus message") + } + hs.finishedHash.Write(cs.marshal()) + + if cs.statusType == statusTypeOCSP { + c.ocspResponse = cs.response + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + keyAgreement := hs.suite.ka(c.vers) + + // Set the public key used to verify the handshake. + pk := c.peerCertificates[0].PublicKey + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.finishedHash.Write(skx.marshal()) + + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, pk, skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + hs.finishedHash.Write(certReq.marshal()) + + if chainToSend, err = hs.getCertificate(certReq); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.finishedHash.Write(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg = new(certificateMsg) + certMsg.certificates = chainToSend.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, pk) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + hs.finishedHash.Write(ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } + } + c.useEMS = hs.serverHello.extendedMSSupported + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random, hs.finishedHash, c.useEMS) + + if err := c.config.writeKeyLog("CLIENT_RANDOM", hs.hello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to write to key log: " + err.Error()) + } + + if chainToSend != nil && len(chainToSend.Certificate) > 0 { + certVerify := &certificateVerifyMsg{ + hasSignatureAndHash: c.vers >= VersionTLS12, + } + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(key.Public(), certReq.supportedSignatureAlgorithms, hs.hello.supportedSignatureAlgorithms, c.vers) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + // SignatureAndHashAlgorithm was introduced in TLS 1.2. + if certVerify.hasSignatureAndHash { + certVerify.signatureAlgorithm = signatureAlgorithm + } + digest, err := hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + signOpts := crypto.SignerOpts(hashFunc) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc} + } + certVerify.signature, err = key.Sign(c.config.rand(), digest, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.finishedHash.Write(certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(c.vers, clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(c.vers, serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { + c.secureRenegotiation = true + if len(hs.serverHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + } + + if c.handshakes > 0 && c.secureRenegotiation { + var expectedSecureRenegotiation [24]byte + copy(expectedSecureRenegotiation[:], c.clientFinished[:]) + copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) + if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: incorrect renegotiation extension contents") + } + } + + if hs.serverHello.extendedMSSupported { + if hs.hello.extendedMSSupported { + c.useEMS = true + } else { + // server wants to calculate master secret in a different way than client + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: unexpected extension (EMS) received in SH") + } + } + + clientDidNPN := hs.hello.nextProtoNeg + clientDidALPN := len(hs.hello.alpnProtocols) > 0 + serverHasNPN := hs.serverHello.nextProtoNeg + serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 + + if !clientDidNPN && serverHasNPN { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server advertised unrequested NPN extension") + } + + if !clientDidALPN && serverHasALPN { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server advertised unrequested ALPN extension") + } + + if serverHasNPN && serverHasALPN { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server advertised both NPN and ALPN extensions") + } + + if serverHasALPN { + c.clientProtocol = hs.serverHello.alpnProtocol + c.clientProtocolFallback = false + } + c.scts = hs.serverHello.scts + + if !hs.serverResumedSession() { + return false, nil + } + + if hs.session.useEMS != c.useEMS { + return false, errors.New("differing EMS state") + } + + if hs.session.vers != c.vers { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different version") + } + + if hs.session.cipherSuite != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different cipher suite") + } + + // Restore masterSecret and peerCerts from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + return true, nil +} + +func (hs *clientHandshakeState) readFinished(out []byte) error { + c := hs.c + + c.readRecord(recordTypeChangeCipherSpec) + if c.in.err != nil { + return c.in.err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertDecryptError) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &ClientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + useEMS: c.useEMS, + } + + return nil +} + +func (hs *clientHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + if hs.serverHello.nextProtoNeg { + nextProto := new(nextProtoMsg) + proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) + nextProto.proto = proto + c.clientProtocol = proto + c.clientProtocolFallback = fallback + + hs.finishedHash.Write(nextProto.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil { + return err + } + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + copy(out, finished.verifyData) + return nil +} + +// tls11SignatureSchemes contains the signature schemes that we synthesise for +// a TLS <= 1.1 connection, based on the supported certificate types. +var tls11SignatureSchemes = []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1} + +const ( + // tls11SignatureSchemesNumECDSA is the number of initial elements of + // tls11SignatureSchemes that use ECDSA. + tls11SignatureSchemesNumECDSA = 3 + // tls11SignatureSchemesNumRSA is the number of trailing elements of + // tls11SignatureSchemes that use RSA. + tls11SignatureSchemesNumRSA = 4 +) + +func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) (*Certificate, error) { + c := hs.c + + var rsaAvail, ecdsaAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecdsaAvail = true + } + } + + if c.config.GetClientCertificate != nil { + var signatureSchemes []SignatureScheme + + if !certReq.hasSignatureAndHash { + // Prior to TLS 1.2, the signature schemes were not + // included in the certificate request message. In this + // case we use a plausible list based on the acceptable + // certificate types. + signatureSchemes = tls11SignatureSchemes + if !ecdsaAvail { + signatureSchemes = signatureSchemes[tls11SignatureSchemesNumECDSA:] + } + if !rsaAvail { + signatureSchemes = signatureSchemes[:len(signatureSchemes)-tls11SignatureSchemesNumRSA] + } + } else { + signatureSchemes = certReq.supportedSignatureAlgorithms + } + + return c.config.GetClientCertificate(&CertificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + SignatureSchemes: signatureSchemes, + }) + } + + // RFC 4346 on the certificateAuthorities field: A list of the + // distinguished names of acceptable certificate authorities. + // These distinguished names may specify a desired + // distinguished name for a root CA or for a subordinate CA; + // thus, this message can be used to describe both known roots + // and a desired authorization space. If the + // certificate_authorities list is empty then the client MAY + // send any certificate of the appropriate + // ClientCertificateType, unless there is some external + // arrangement to the contrary. + + // We need to search our list of client certs for one + // where SignatureAlgorithm is acceptable to the server and the + // Issuer is in certReq.certificateAuthorities +findCert: + for i, chain := range c.config.Certificates { + if !rsaAvail && !ecdsaAvail { + continue + } + + for j, cert := range chain.Certificate { + x509Cert := chain.Leaf + // parse the certificate if this isn't the leaf + // node, or if chain.Leaf was nil + if j != 0 || x509Cert == nil { + var err error + if x509Cert, err = x509.ParseCertificate(cert); err != nil { + c.sendAlert(alertInternalError) + return nil, errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error()) + } + } + + switch { + case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA: + case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA: + default: + continue findCert + } + + if len(certReq.certificateAuthorities) == 0 { + // they gave us an empty list, so just take the + // first cert from c.config.Certificates + return &chain, nil + } + + for _, ca := range certReq.certificateAuthorities { + if bytes.Equal(x509Cert.RawIssuer, ca) { + return &chain, nil + } + } + } + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { + if len(config.ServerName) > 0 { + return config.ServerName + } + return serverAddr.String() +} + +// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol +// given list of possible protocols and a list of the preference order. The +// first list must not be empty. It returns the resulting protocol and flag +// indicating if the fallback case was reached. +func mutualProtocol(protos, preferenceProtos []string) (string, bool) { + for _, s := range preferenceProtos { + for _, c := range protos { + if s == c { + return s, false + } + } + } + + return protos[0], true +} + +// hostnameInSNI converts name into an appropriate hostname for SNI. +// Literal IP addresses and absolute FQDNs are not permitted as SNI values. +// See https://tools.ietf.org/html/rfc6066#section-3. +func hostnameInSNI(name string) string { + host := name + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + if i := strings.LastIndex(host, "%"); i > 0 { + host = host[:i] + } + if net.ParseIP(host) != nil { + return "" + } + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return name +} diff --git a/vendor/github.com/marten-seemann/qtls/handshake_messages.go b/vendor/github.com/marten-seemann/qtls/handshake_messages.go new file mode 100644 index 00000000..b060e11f --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/handshake_messages.go @@ -0,0 +1,2781 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "encoding/binary" + "strings" +) + +// signAlgosCertList helper function returns either list of signature algorithms in case +// signature_algorithms_cert extension should be marshalled or nil in the other case. +// signAlgos is a list of algorithms from signature_algorithms extension. signAlgosCert is a list +// of algorithms from signature_algorithms_cert extension. +func signAlgosCertList(signAlgos, signAlgosCert []SignatureScheme) []SignatureScheme { + if eqSignatureAlgorithms(signAlgos, signAlgosCert) { + // ensure that only supported_algorithms extension is send if supported_algorithms_cert + // has identical content + return nil + } + return signAlgosCert +} + +type clientHelloMsg struct { + raw []byte + rawTruncated []byte // for PSK binding + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + nextProtoNeg bool + serverName string + ocspStapling bool + scts bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiation []byte + secureRenegotiationSupported bool + alpnProtocols []string + keyShares []keyShare + supportedVersions []uint16 + psks []psk + pskKeyExchangeModes []uint8 + earlyData bool + delegatedCredential bool + extendedMSSupported bool // RFC7627 + additionalExtensions []Extension +} + +// Function used for signature_algorithms and signature_algorithrms_cert +// extensions only (for more details, see TLS 1.3 draft 28, 4.2.3). +// +// It advances data slice and returns it, so that it can be used for further +// processing +func marshalExtensionSignatureAlgorithms(extension uint16, data []byte, schemes []SignatureScheme) []byte { + algNum := uint16(len(schemes)) + if algNum == 0 { + return data + } + + binary.BigEndian.PutUint16(data, extension) + data = data[2:] + binary.BigEndian.PutUint16(data, (2*algNum)+2) // +1 for length + data = data[2:] + binary.BigEndian.PutUint16(data, (2 * algNum)) + data = data[2:] + + for _, algo := range schemes { + binary.BigEndian.PutUint16(data, uint16(algo)) + data = data[2:] + } + return data +} + +// Function used for unmarshalling signature_algorithms or signature_algorithms_cert extensions only +// (for more details, see TLS 1.3 draft 28, 4.2.3) +// In case of error function returns alertDecoderError otherwise filled SignatureScheme slice and alertSuccess +func unmarshalExtensionSignatureAlgorithms(data []byte, length int) ([]SignatureScheme, alert) { + + if length < 2 || length&1 != 0 { + return nil, alertDecodeError + } + + algLen := binary.BigEndian.Uint16(data) + idx := 2 + + if int(algLen) != length-2 { + return nil, alertDecodeError + } + + schemes := make([]SignatureScheme, algLen/2) + for i := range schemes { + schemes[i] = SignatureScheme(binary.BigEndian.Uint16(data[idx:])) + idx += 2 + } + return schemes, alertSuccess +} + +func (m *clientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*clientHelloMsg) + if !ok { + return false + } + + if len(m.additionalExtensions) != len(m1.additionalExtensions) { + return false + } + for i, ex := range m.additionalExtensions { + ex1 := m1.additionalExtensions[i] + if ex.Type != ex1.Type || !bytes.Equal(ex.Data, ex1.Data) { + return false + } + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.compressionMethods, m1.compressionMethods) && + m.nextProtoNeg == m1.nextProtoNeg && + m.serverName == m1.serverName && + m.ocspStapling == m1.ocspStapling && + m.scts == m1.scts && + eqCurveIDs(m.supportedCurves, m1.supportedCurves) && + bytes.Equal(m.supportedPoints, m1.supportedPoints) && + m.ticketSupported == m1.ticketSupported && + bytes.Equal(m.sessionTicket, m1.sessionTicket) && + eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) && + eqSignatureAlgorithms(m.supportedSignatureAlgorithmsCert, m1.supportedSignatureAlgorithmsCert) && + m.secureRenegotiationSupported == m1.secureRenegotiationSupported && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + eqStrings(m.alpnProtocols, m1.alpnProtocols) && + eqKeyShares(m.keyShares, m1.keyShares) && + eqUint16s(m.supportedVersions, m1.supportedVersions) && + m.earlyData == m1.earlyData && + m.delegatedCredential == m1.delegatedCredential && + m.extendedMSSupported == m1.extendedMSSupported +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) + numExtensions := 0 + extensionsLength := 0 + + if m.nextProtoNeg { + numExtensions++ + } + if m.ocspStapling { + extensionsLength += 1 + 2 + 2 + numExtensions++ + } + if len(m.serverName) > 0 { + extensionsLength += 5 + len(m.serverName) + numExtensions++ + } + if len(m.supportedCurves) > 0 { + extensionsLength += 2 + 2*len(m.supportedCurves) + numExtensions++ + } + if len(m.supportedPoints) > 0 { + extensionsLength += 1 + len(m.supportedPoints) + numExtensions++ + } + if m.ticketSupported { + extensionsLength += len(m.sessionTicket) + numExtensions++ + } + if len(m.supportedSignatureAlgorithms) > 0 { + extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms) + numExtensions++ + } + if m.getSignatureAlgorithmsCert() != nil { + extensionsLength += 2 + 2*len(m.getSignatureAlgorithmsCert()) + numExtensions++ + } + if m.secureRenegotiationSupported { + extensionsLength += 1 + len(m.secureRenegotiation) + numExtensions++ + } + if len(m.alpnProtocols) > 0 { + extensionsLength += 2 + for _, s := range m.alpnProtocols { + if l := len(s); l == 0 || l > 255 { + panic("invalid ALPN protocol") + } + extensionsLength++ + extensionsLength += len(s) + } + numExtensions++ + } + if m.scts { + numExtensions++ + } + if len(m.keyShares) > 0 { + extensionsLength += 2 + for _, k := range m.keyShares { + extensionsLength += 4 + len(k.data) + } + numExtensions++ + } + if len(m.supportedVersions) > 0 { + extensionsLength += 1 + 2*len(m.supportedVersions) + numExtensions++ + } + if m.earlyData { + numExtensions++ + } + if m.delegatedCredential { + numExtensions++ + } + if m.extendedMSSupported { + numExtensions++ + } + if len(m.additionalExtensions) > 0 { + numExtensions += len(m.additionalExtensions) + for _, ex := range m.additionalExtensions { + extensionsLength += len(ex.Data) + } + } + if numExtensions > 0 { + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength + } + + x := make([]byte, 4+length) + x[0] = typeClientHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + y := x[39+len(m.sessionId):] + y[0] = uint8(len(m.cipherSuites) >> 7) + y[1] = uint8(len(m.cipherSuites) << 1) + for i, suite := range m.cipherSuites { + y[2+i*2] = uint8(suite >> 8) + y[3+i*2] = uint8(suite) + } + z := y[2+len(m.cipherSuites)*2:] + z[0] = uint8(len(m.compressionMethods)) + copy(z[1:], m.compressionMethods) + + z = z[1+len(m.compressionMethods):] + if numExtensions > 0 { + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) + z = z[2:] + } + if m.nextProtoNeg { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + // The length is always 0 + z = z[4:] + } + if len(m.serverName) > 0 { + z[0] = byte(extensionServerName >> 8) + z[1] = byte(extensionServerName & 0xff) + l := len(m.serverName) + 5 + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + // RFC 3546, section 3.1 + // + // struct { + // NameType name_type; + // select (name_type) { + // case host_name: HostName; + // } name; + // } ServerName; + // + // enum { + // host_name(0), (255) + // } NameType; + // + // opaque HostName<1..2^16-1>; + // + // struct { + // ServerName server_name_list<1..2^16-1> + // } ServerNameList; + + z[0] = byte((len(m.serverName) + 3) >> 8) + z[1] = byte(len(m.serverName) + 3) + z[3] = byte(len(m.serverName) >> 8) + z[4] = byte(len(m.serverName)) + copy(z[5:], []byte(m.serverName)) + z = z[l:] + } + if m.ocspStapling { + // RFC 4366, section 3.6 + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z[2] = 0 + z[3] = 5 + z[4] = 1 // OCSP type + // Two zero valued uint16s for the two lengths. + z = z[9:] + } + if len(m.supportedCurves) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.1 + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4 + z[0] = byte(extensionSupportedCurves >> 8) + z[1] = byte(extensionSupportedCurves) + l := 2 + 2*len(m.supportedCurves) + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + z = z[6:] + for _, curve := range m.supportedCurves { + z[0] = byte(curve >> 8) + z[1] = byte(curve) + z = z[2:] + } + } + if len(m.supportedPoints) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.2 + z[0] = byte(extensionSupportedPoints >> 8) + z[1] = byte(extensionSupportedPoints) + l := 1 + len(m.supportedPoints) + z[2] = byte(l >> 8) + z[3] = byte(l) + l-- + z[4] = byte(l) + z = z[5:] + for _, pointFormat := range m.supportedPoints { + z[0] = pointFormat + z = z[1:] + } + } + if m.ticketSupported { + // http://tools.ietf.org/html/rfc5077#section-3.2 + z[0] = byte(extensionSessionTicket >> 8) + z[1] = byte(extensionSessionTicket) + l := len(m.sessionTicket) + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + copy(z, m.sessionTicket) + z = z[len(m.sessionTicket):] + } + + if len(m.supportedSignatureAlgorithms) > 0 { + z = marshalExtensionSignatureAlgorithms(extensionSignatureAlgorithms, z, m.supportedSignatureAlgorithms) + } + if m.getSignatureAlgorithmsCert() != nil { + // Ensure only one list of algorithms is sent if supported_algorithms and supported_algorithms_cert are the same + z = marshalExtensionSignatureAlgorithms(extensionSignatureAlgorithmsCert, z, m.getSignatureAlgorithmsCert()) + } + + if m.secureRenegotiationSupported { + z[0] = byte(extensionRenegotiationInfo >> 8) + z[1] = byte(extensionRenegotiationInfo & 0xff) + z[2] = 0 + z[3] = byte(len(m.secureRenegotiation) + 1) + z[4] = byte(len(m.secureRenegotiation)) + z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] + } + if len(m.alpnProtocols) > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + lengths := z[2:] + z = z[6:] + + stringsLength := 0 + for _, s := range m.alpnProtocols { + l := len(s) + z[0] = byte(l) + copy(z[1:], s) + z = z[1+l:] + stringsLength += 1 + l + } + + lengths[2] = byte(stringsLength >> 8) + lengths[3] = byte(stringsLength) + stringsLength += 2 + lengths[0] = byte(stringsLength >> 8) + lengths[1] = byte(stringsLength) + } + if m.scts { + // https://tools.ietf.org/html/rfc6962#section-3.3.1 + z[0] = byte(extensionSCT >> 8) + z[1] = byte(extensionSCT) + // zero uint16 for the zero-length extension_data + z = z[4:] + } + if len(m.keyShares) > 0 { + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5 + z[0] = byte(extensionKeyShare >> 8) + z[1] = byte(extensionKeyShare) + lengths := z[2:] + z = z[6:] + + totalLength := 0 + for _, ks := range m.keyShares { + z[0] = byte(ks.group >> 8) + z[1] = byte(ks.group) + z[2] = byte(len(ks.data) >> 8) + z[3] = byte(len(ks.data)) + copy(z[4:], ks.data) + z = z[4+len(ks.data):] + totalLength += 4 + len(ks.data) + } + + lengths[2] = byte(totalLength >> 8) + lengths[3] = byte(totalLength) + totalLength += 2 + lengths[0] = byte(totalLength >> 8) + lengths[1] = byte(totalLength) + } + if len(m.supportedVersions) > 0 { + z[0] = byte(extensionSupportedVersions >> 8) + z[1] = byte(extensionSupportedVersions) + l := 1 + 2*len(m.supportedVersions) + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 1 + z[4] = byte(l) + z = z[5:] + for _, v := range m.supportedVersions { + z[0] = byte(v >> 8) + z[1] = byte(v) + z = z[2:] + } + } + if m.earlyData { + z[0] = byte(extensionEarlyData >> 8) + z[1] = byte(extensionEarlyData) + z = z[4:] + } + if m.delegatedCredential { + binary.BigEndian.PutUint16(z, extensionDelegatedCredential) + z = z[4:] + } + if m.extendedMSSupported { + binary.BigEndian.PutUint16(z, extensionEMS) + z = z[4:] + } + for _, ex := range m.additionalExtensions { + z[0] = byte(ex.Type >> 8) + z[1] = byte(ex.Type) + l := len(ex.Data) + z[2] = byte(l >> 8) + z[3] = byte(l) + copy(z[4:], ex.Data) + z = z[4+l:] + } + + m.raw = x + + return x +} + +func (m *clientHelloMsg) unmarshal(data []byte) alert { + if len(data) < 42 { + return alertDecodeError + } + m.raw = data + m.vers = uint16(data[4])<<8 | uint16(data[5]) + m.random = data[6:38] + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return alertDecodeError + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + bindersOffset := 39 + sessionIdLen + if len(data) < 2 { + return alertDecodeError + } + // cipherSuiteLen is the number of bytes of cipher suite numbers. Since + // they are uint16s, the number must be even. + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + return alertDecodeError + } + numCipherSuites := cipherSuiteLen / 2 + m.cipherSuites = make([]uint16, numCipherSuites) + for i := 0; i < numCipherSuites; i++ { + m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) + if m.cipherSuites[i] == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + } + data = data[2+cipherSuiteLen:] + bindersOffset += 2 + cipherSuiteLen + if len(data) < 1 { + return alertDecodeError + } + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + return alertDecodeError + } + m.compressionMethods = data[1 : 1+compressionMethodsLen] + + data = data[1+compressionMethodsLen:] + bindersOffset += 1 + compressionMethodsLen + + m.nextProtoNeg = false + m.serverName = "" + m.ocspStapling = false + m.ticketSupported = false + m.sessionTicket = nil + m.supportedSignatureAlgorithms = nil + m.alpnProtocols = nil + m.scts = false + m.keyShares = nil + m.supportedVersions = nil + m.psks = nil + m.pskKeyExchangeModes = nil + m.earlyData = false + m.delegatedCredential = false + m.extendedMSSupported = false + + if len(data) == 0 { + // ClientHello is optionally followed by extension data + return alertSuccess + } + if len(data) < 2 { + return alertDecodeError + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + bindersOffset += 2 + if extensionsLength != len(data) { + return alertDecodeError + } + + for len(data) != 0 { + if len(data) < 4 { + return alertDecodeError + } + ext := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + bindersOffset += 4 + if len(data) < length { + return alertDecodeError + } + + switch ext { + case extensionServerName: + d := data[:length] + if len(d) < 2 { + return alertDecodeError + } + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + return alertDecodeError + } + for len(d) > 0 { + if len(d) < 3 { + return alertDecodeError + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + return alertDecodeError + } + if nameType == 0 { + m.serverName = string(d[:nameLen]) + // An SNI value may not include a + // trailing dot. See + // https://tools.ietf.org/html/rfc6066#section-3. + if strings.HasSuffix(m.serverName, ".") { + // TODO use alertDecodeError? + return alertUnexpectedMessage + } + break + } + d = d[nameLen:] + } + case extensionNextProtoNeg: + if length > 0 { + return alertDecodeError + } + m.nextProtoNeg = true + case extensionStatusRequest: + m.ocspStapling = length > 0 && data[0] == statusTypeOCSP + case extensionSupportedCurves: + // http://tools.ietf.org/html/rfc4492#section-5.5.1 + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.4 + if length < 2 { + return alertDecodeError + } + l := int(data[0])<<8 | int(data[1]) + if l%2 == 1 || length != l+2 { + return alertDecodeError + } + numCurves := l / 2 + m.supportedCurves = make([]CurveID, numCurves) + d := data[2:] + for i := 0; i < numCurves; i++ { + m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1]) + d = d[2:] + } + case extensionSupportedPoints: + // http://tools.ietf.org/html/rfc4492#section-5.5.2 + if length < 1 { + return alertDecodeError + } + l := int(data[0]) + if length != l+1 { + return alertDecodeError + } + m.supportedPoints = make([]uint8, l) + copy(m.supportedPoints, data[1:]) + case extensionSessionTicket: + // http://tools.ietf.org/html/rfc5077#section-3.2 + m.ticketSupported = true + m.sessionTicket = data[:length] + case extensionSignatureAlgorithms: + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3 + if length < 2 || length&1 != 0 { + return alertDecodeError + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return alertDecodeError + } + n := l / 2 + d := data[2:] + m.supportedSignatureAlgorithms = make([]SignatureScheme, n) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1]) + d = d[2:] + } + case extensionRenegotiationInfo: + if length == 0 { + return alertDecodeError + } + d := data[:length] + l := int(d[0]) + d = d[1:] + if l != len(d) { + return alertDecodeError + } + + m.secureRenegotiation = d + m.secureRenegotiationSupported = true + case extensionALPN: + if length < 2 { + return alertDecodeError + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return alertDecodeError + } + d := data[2:length] + for len(d) != 0 { + stringLen := int(d[0]) + d = d[1:] + if stringLen == 0 || stringLen > len(d) { + return alertDecodeError + } + m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) + d = d[stringLen:] + } + case extensionSCT: + m.scts = true + if length != 0 { + return alertDecodeError + } + case extensionKeyShare: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.5 + if length < 2 { + return alertDecodeError + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return alertDecodeError + } + d := data[2:length] + for len(d) != 0 { + if len(d) < 4 { + return alertDecodeError + } + dataLen := int(d[2])<<8 | int(d[3]) + if dataLen == 0 || 4+dataLen > len(d) { + return alertDecodeError + } + m.keyShares = append(m.keyShares, keyShare{ + group: CurveID(d[0])<<8 | CurveID(d[1]), + data: d[4 : 4+dataLen], + }) + d = d[4+dataLen:] + } + case extensionSupportedVersions: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.1 + if length < 1 { + return alertDecodeError + } + l := int(data[0]) + if l%2 == 1 || length != l+1 { + return alertDecodeError + } + n := l / 2 + d := data[1:] + for i := 0; i < n; i++ { + v := uint16(d[0])<<8 + uint16(d[1]) + m.supportedVersions = append(m.supportedVersions, v) + d = d[2:] + } + case extensionPreSharedKey: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 + if length < 2 { + return alertDecodeError + } + // Ensure this extension is the last one in the Client Hello + if len(data) != length { + return alertIllegalParameter + } + li := int(data[0])<<8 | int(data[1]) + if 2+li+2 > length { + return alertDecodeError + } + d := data[2 : 2+li] + bindersOffset += 2 + li + for len(d) > 0 { + if len(d) < 6 { + return alertDecodeError + } + l := int(d[0])<<8 | int(d[1]) + if len(d) < 2+l+4 { + return alertDecodeError + } + m.psks = append(m.psks, psk{ + identity: d[2 : 2+l], + obfTicketAge: uint32(d[l+2])<<24 | uint32(d[l+3])<<16 | + uint32(d[l+4])<<8 | uint32(d[l+5]), + }) + d = d[2+l+4:] + } + lb := int(data[li+2])<<8 | int(data[li+3]) + d = data[2+li+2:] + if lb != len(d) || lb == 0 { + return alertDecodeError + } + i := 0 + for len(d) > 0 { + if i >= len(m.psks) { + return alertIllegalParameter + } + if len(d) < 1 { + return alertDecodeError + } + l := int(d[0]) + if l > len(d)-1 { + return alertDecodeError + } + if i >= len(m.psks) { + return alertIllegalParameter + } + m.psks[i].binder = d[1 : 1+l] + d = d[1+l:] + i++ + } + if i != len(m.psks) { + return alertIllegalParameter + } + m.rawTruncated = m.raw[:bindersOffset] + case extensionPSKKeyExchangeModes: + if length < 2 { + return alertDecodeError + } + l := int(data[0]) + if length != l+1 { + return alertDecodeError + } + m.pskKeyExchangeModes = data[1:length] + case extensionEarlyData: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8 + m.earlyData = true + case extensionDelegatedCredential: + // https://tools.ietf.org/html/draft-ietf-tls-subcerts-02 + m.delegatedCredential = true + case extensionEMS: + // RFC 7627 + m.extendedMSSupported = true + if length != 0 { + return alertDecodeError + } + default: + m.additionalExtensions = append(m.additionalExtensions, + Extension{Type: ext, Data: data[:length]}) + } + data = data[length:] + bindersOffset += length + } + + return alertSuccess +} + +func (m *clientHelloMsg) getSignatureAlgorithmsCert() []SignatureScheme { + return signAlgosCertList(m.supportedSignatureAlgorithms, m.supportedSignatureAlgorithmsCert) +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + nextProtoNeg bool + nextProtos []string + ocspStapling bool + scts [][]byte + ticketSupported bool + secureRenegotiation []byte + secureRenegotiationSupported bool + alpnProtocol string + + // TLS 1.3 + keyShare keyShare + psk bool + pskIdentity uint16 + + // RFC7627 + extendedMSSupported bool +} + +func (m *serverHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*serverHelloMsg) + if !ok { + return false + } + + if len(m.scts) != len(m1.scts) { + return false + } + for i, sct := range m.scts { + if !bytes.Equal(sct, m1.scts[i]) { + return false + } + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + m.cipherSuite == m1.cipherSuite && + m.compressionMethod == m1.compressionMethod && + m.nextProtoNeg == m1.nextProtoNeg && + eqStrings(m.nextProtos, m1.nextProtos) && + m.ocspStapling == m1.ocspStapling && + m.ticketSupported == m1.ticketSupported && + m.secureRenegotiationSupported == m1.secureRenegotiationSupported && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + m.alpnProtocol == m1.alpnProtocol && + m.keyShare.group == m1.keyShare.group && + bytes.Equal(m.keyShare.data, m1.keyShare.data) && + m.psk == m1.psk && + m.pskIdentity == m1.pskIdentity && + m.extendedMSSupported == m1.extendedMSSupported +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 38 + len(m.sessionId) + numExtensions := 0 + extensionsLength := 0 + + nextProtoLen := 0 + if m.nextProtoNeg { + numExtensions++ + for _, v := range m.nextProtos { + nextProtoLen += len(v) + } + nextProtoLen += len(m.nextProtos) + extensionsLength += nextProtoLen + } + if m.ocspStapling { + numExtensions++ + } + if m.ticketSupported { + numExtensions++ + } + if m.secureRenegotiationSupported { + extensionsLength += 1 + len(m.secureRenegotiation) + numExtensions++ + } + if m.extendedMSSupported { + numExtensions++ + } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + if alpnLen >= 256 { + panic("invalid ALPN protocol") + } + extensionsLength += 2 + 1 + alpnLen + numExtensions++ + } + sctLen := 0 + if len(m.scts) > 0 { + for _, sct := range m.scts { + sctLen += len(sct) + 2 + } + extensionsLength += 2 + sctLen + numExtensions++ + } + if m.keyShare.group != 0 { + extensionsLength += 4 + len(m.keyShare.data) + numExtensions++ + } + if m.psk { + extensionsLength += 2 + numExtensions++ + } + // supported_versions extension + if m.vers >= VersionTLS13 { + extensionsLength += 2 + numExtensions++ + } + + if numExtensions > 0 { + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength + } + + x := make([]byte, 4+length) + x[0] = typeServerHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + if m.vers >= VersionTLS13 { + x[4] = 3 + x[5] = 3 + } else { + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) + } + copy(x[6:38], m.random) + z := x[38:] + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + z = x[39+len(m.sessionId):] + z[0] = uint8(m.cipherSuite >> 8) + z[1] = uint8(m.cipherSuite) + z[2] = m.compressionMethod + z = z[3:] + + if numExtensions > 0 { + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) + z = z[2:] + } + if m.vers >= VersionTLS13 { + z[0] = byte(extensionSupportedVersions >> 8) + z[1] = byte(extensionSupportedVersions) + z[3] = 2 + z[4] = uint8(m.vers >> 8) + z[5] = uint8(m.vers) + z = z[6:] + } + if m.nextProtoNeg { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + z[2] = byte(nextProtoLen >> 8) + z[3] = byte(nextProtoLen) + z = z[4:] + + for _, v := range m.nextProtos { + l := len(v) + if l > 255 { + l = 255 + } + z[0] = byte(l) + copy(z[1:], []byte(v[0:l])) + z = z[1+l:] + } + } + if m.ocspStapling { + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z = z[4:] + } + if m.ticketSupported { + z[0] = byte(extensionSessionTicket >> 8) + z[1] = byte(extensionSessionTicket) + z = z[4:] + } + if m.secureRenegotiationSupported { + z[0] = byte(extensionRenegotiationInfo >> 8) + z[1] = byte(extensionRenegotiationInfo & 0xff) + z[2] = 0 + z[3] = byte(len(m.secureRenegotiation) + 1) + z[4] = byte(len(m.secureRenegotiation)) + z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] + } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + l := 2 + 1 + alpnLen + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + l -= 1 + z[6] = byte(l) + copy(z[7:], []byte(m.alpnProtocol)) + z = z[7+alpnLen:] + } + if sctLen > 0 { + z[0] = byte(extensionSCT >> 8) + z[1] = byte(extensionSCT) + l := sctLen + 2 + z[2] = byte(l >> 8) + z[3] = byte(l) + z[4] = byte(sctLen >> 8) + z[5] = byte(sctLen) + + z = z[6:] + for _, sct := range m.scts { + z[0] = byte(len(sct) >> 8) + z[1] = byte(len(sct)) + copy(z[2:], sct) + z = z[len(sct)+2:] + } + } + if m.keyShare.group != 0 { + z[0] = uint8(extensionKeyShare >> 8) + z[1] = uint8(extensionKeyShare) + l := 4 + len(m.keyShare.data) + z[2] = uint8(l >> 8) + z[3] = uint8(l) + z[4] = uint8(m.keyShare.group >> 8) + z[5] = uint8(m.keyShare.group) + l -= 4 + z[6] = uint8(l >> 8) + z[7] = uint8(l) + copy(z[8:], m.keyShare.data) + z = z[8+l:] + } + + if m.psk { + z[0] = byte(extensionPreSharedKey >> 8) + z[1] = byte(extensionPreSharedKey) + z[3] = 2 + z[4] = byte(m.pskIdentity >> 8) + z[5] = byte(m.pskIdentity) + z = z[6:] + } + if m.extendedMSSupported { + binary.BigEndian.PutUint16(z, extensionEMS) + z = z[4:] + } + + m.raw = x + + return x +} + +func (m *serverHelloMsg) unmarshal(data []byte) alert { + if len(data) < 42 { + return alertDecodeError + } + m.raw = data + m.vers = uint16(data[4])<<8 | uint16(data[5]) + m.random = data[6:38] + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return alertDecodeError + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + if len(data) < 3 { + return alertDecodeError + } + m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) + m.compressionMethod = data[2] + data = data[3:] + + m.nextProtoNeg = false + m.nextProtos = nil + m.ocspStapling = false + m.scts = nil + m.ticketSupported = false + m.alpnProtocol = "" + m.keyShare.group = 0 + m.keyShare.data = nil + m.psk = false + m.pskIdentity = 0 + m.extendedMSSupported = false + + if len(data) == 0 { + // ServerHello is optionally followed by extension data + return alertSuccess + } + if len(data) < 2 { + return alertDecodeError + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if len(data) != extensionsLength { + return alertDecodeError + } + + svData := findExtension(data, extensionSupportedVersions) + if svData != nil { + if len(svData) != 2 { + return alertDecodeError + } + if m.vers != VersionTLS12 { + return alertDecodeError + } + rcvVer := binary.BigEndian.Uint16(svData[0:]) + if rcvVer < VersionTLS13 { + return alertIllegalParameter + } + m.vers = rcvVer + } + + for len(data) != 0 { + if len(data) < 4 { + return alertDecodeError + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return alertDecodeError + } + + switch extension { + case extensionNextProtoNeg: + m.nextProtoNeg = true + d := data[:length] + for len(d) > 0 { + l := int(d[0]) + d = d[1:] + if l == 0 || l > len(d) { + return alertDecodeError + } + m.nextProtos = append(m.nextProtos, string(d[:l])) + d = d[l:] + } + case extensionStatusRequest: + if length > 0 { + return alertDecodeError + } + m.ocspStapling = true + case extensionSessionTicket: + if length > 0 { + return alertDecodeError + } + m.ticketSupported = true + case extensionRenegotiationInfo: + if length == 0 { + return alertDecodeError + } + d := data[:length] + l := int(d[0]) + d = d[1:] + if l != len(d) { + return alertDecodeError + } + + m.secureRenegotiation = d + m.secureRenegotiationSupported = true + case extensionALPN: + d := data[:length] + if len(d) < 3 { + return alertDecodeError + } + l := int(d[0])<<8 | int(d[1]) + if l != len(d)-2 { + return alertDecodeError + } + d = d[2:] + l = int(d[0]) + if l != len(d)-1 { + return alertDecodeError + } + d = d[1:] + if len(d) == 0 { + // ALPN protocols must not be empty. + return alertDecodeError + } + m.alpnProtocol = string(d) + case extensionSCT: + d := data[:length] + + if len(d) < 2 { + return alertDecodeError + } + l := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != l || l == 0 { + return alertDecodeError + } + + m.scts = make([][]byte, 0, 3) + for len(d) != 0 { + if len(d) < 2 { + return alertDecodeError + } + sctLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if sctLen == 0 || len(d) < sctLen { + return alertDecodeError + } + m.scts = append(m.scts, d[:sctLen]) + d = d[sctLen:] + } + case extensionKeyShare: + d := data[:length] + + if len(d) < 4 { + return alertDecodeError + } + m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1]) + l := int(d[2])<<8 | int(d[3]) + d = d[4:] + if len(d) != l { + return alertDecodeError + } + m.keyShare.data = d[:l] + case extensionPreSharedKey: + if length != 2 { + return alertDecodeError + } + m.psk = true + m.pskIdentity = uint16(data[0])<<8 | uint16(data[1]) + case extensionEMS: + m.extendedMSSupported = true + } + data = data[length:] + } + + return alertSuccess +} + +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string + earlyData bool + + additionalExtensions []Extension +} + +func (m *encryptedExtensionsMsg) equal(i interface{}) bool { + m1, ok := i.(*encryptedExtensionsMsg) + if !ok { + return false + } + + if len(m.additionalExtensions) != len(m1.additionalExtensions) { + return false + } + for i, ex := range m.additionalExtensions { + ex1 := m1.additionalExtensions[i] + if ex.Type != ex1.Type || !bytes.Equal(ex.Data, ex1.Data) { + return false + } + } + + return bytes.Equal(m.raw, m1.raw) && + m.alpnProtocol == m1.alpnProtocol && + m.earlyData == m1.earlyData +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 2 + + if m.earlyData { + length += 4 + } + alpnLen := len(m.alpnProtocol) + if alpnLen > 0 { + if alpnLen >= 256 { + panic("invalid ALPN protocol") + } + length += 2 + 2 + 2 + 1 + alpnLen + } + if len(m.additionalExtensions) > 0 { + for _, ex := range m.additionalExtensions { + length += 4 + len(ex.Data) + } + } + + x := make([]byte, 4+length) + x[0] = typeEncryptedExtensions + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + length -= 2 + x[4] = uint8(length >> 8) + x[5] = uint8(length) + + z := x[6:] + if alpnLen > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN) + l := 2 + 1 + alpnLen + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + l -= 1 + z[6] = byte(l) + copy(z[7:], []byte(m.alpnProtocol)) + z = z[7+alpnLen:] + } + + if m.earlyData { + z[0] = byte(extensionEarlyData >> 8) + z[1] = byte(extensionEarlyData) + z = z[4:] + } + + for _, ex := range m.additionalExtensions { + z[0] = byte(ex.Type >> 8) + z[1] = byte(ex.Type) + l := len(ex.Data) + z[2] = byte(l >> 8) + z[3] = byte(l) + copy(z[4:], ex.Data) + z = z[4+l:] + } + + m.raw = x + return x +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) alert { + if len(data) < 6 { + return alertDecodeError + } + m.raw = data + + m.alpnProtocol = "" + m.earlyData = false + + extensionsLength := int(data[4])<<8 | int(data[5]) + data = data[6:] + if len(data) != extensionsLength { + return alertDecodeError + } + + for len(data) != 0 { + if len(data) < 4 { + return alertDecodeError + } + ext := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return alertDecodeError + } + + switch ext { + case extensionALPN: + d := data[:length] + if len(d) < 3 { + return alertDecodeError + } + l := int(d[0])<<8 | int(d[1]) + if l != len(d)-2 { + return alertDecodeError + } + d = d[2:] + l = int(d[0]) + if l != len(d)-1 { + return alertDecodeError + } + d = d[1:] + if len(d) == 0 { + // ALPN protocols must not be empty. + return alertDecodeError + } + m.alpnProtocol = string(d) + case extensionEarlyData: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8 + m.earlyData = true + default: + m.additionalExtensions = append(m.additionalExtensions, + Extension{Type: ext, Data: data[:length]}) + } + + data = data[length:] + } + + return alertSuccess +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + eqByteSlices(m.certificates, m1.certificates) +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) alert { + if len(data) < 7 { + return alertDecodeError + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return alertDecodeError + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return alertDecodeError + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return alertDecodeError + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return alertSuccess +} + +type certificateEntry struct { + data []byte + ocspStaple []byte + sctList [][]byte + delegatedCredential []byte +} + +type certificateMsg13 struct { + raw []byte + requestContext []byte + certificates []certificateEntry +} + +func (m *certificateMsg13) equal(i interface{}) bool { + m1, ok := i.(*certificateMsg13) + if !ok { + return false + } + + if len(m.certificates) != len(m1.certificates) { + return false + } + for i, _ := range m.certificates { + ok := bytes.Equal(m.certificates[i].data, m1.certificates[i].data) + ok = ok && bytes.Equal(m.certificates[i].ocspStaple, m1.certificates[i].ocspStaple) + ok = ok && eqByteSlices(m.certificates[i].sctList, m1.certificates[i].sctList) + ok = ok && bytes.Equal(m.certificates[i].delegatedCredential, m1.certificates[i].delegatedCredential) + if !ok { + return false + } + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.requestContext, m1.requestContext) +} + +func (m *certificateMsg13) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, cert := range m.certificates { + i += len(cert.data) + if len(cert.ocspStaple) != 0 { + i += 8 + len(cert.ocspStaple) + } + if len(cert.sctList) != 0 { + i += 6 + for _, sct := range cert.sctList { + i += 2 + len(sct) + } + } + if len(cert.delegatedCredential) != 0 { + i += 4 + len(cert.delegatedCredential) + } + } + + length := 3 + 3*len(m.certificates) + i + length += 2 * len(m.certificates) // extensions + length += 1 + len(m.requestContext) + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + z := x[4:] + + z[0] = byte(len(m.requestContext)) + copy(z[1:], m.requestContext) + z = z[1+len(m.requestContext):] + + certificateOctets := len(z) - 3 + z[0] = uint8(certificateOctets >> 16) + z[1] = uint8(certificateOctets >> 8) + z[2] = uint8(certificateOctets) + + z = z[3:] + for _, cert := range m.certificates { + z[0] = uint8(len(cert.data) >> 16) + z[1] = uint8(len(cert.data) >> 8) + z[2] = uint8(len(cert.data)) + copy(z[3:], cert.data) + z = z[3+len(cert.data):] + + extLenPos := z[:2] + z = z[2:] + + extensionLen := 0 + if len(cert.ocspStaple) != 0 { + stapleLen := 4 + len(cert.ocspStaple) + z[0] = uint8(extensionStatusRequest >> 8) + z[1] = uint8(extensionStatusRequest) + z[2] = uint8(stapleLen >> 8) + z[3] = uint8(stapleLen) + + stapleLen -= 4 + z[4] = statusTypeOCSP + z[5] = uint8(stapleLen >> 16) + z[6] = uint8(stapleLen >> 8) + z[7] = uint8(stapleLen) + copy(z[8:], cert.ocspStaple) + z = z[8+stapleLen:] + + extensionLen += 8 + stapleLen + } + if len(cert.sctList) != 0 { + z[0] = uint8(extensionSCT >> 8) + z[1] = uint8(extensionSCT) + sctLenPos := z[2:6] + z = z[6:] + extensionLen += 6 + + sctLen := 2 + for _, sct := range cert.sctList { + z[0] = uint8(len(sct) >> 8) + z[1] = uint8(len(sct)) + copy(z[2:], sct) + z = z[2+len(sct):] + + extensionLen += 2 + len(sct) + sctLen += 2 + len(sct) + } + sctLenPos[0] = uint8(sctLen >> 8) + sctLenPos[1] = uint8(sctLen) + sctLen -= 2 + sctLenPos[2] = uint8(sctLen >> 8) + sctLenPos[3] = uint8(sctLen) + } + if len(cert.delegatedCredential) != 0 { + binary.BigEndian.PutUint16(z, extensionDelegatedCredential) + binary.BigEndian.PutUint16(z[2:], uint16(len(cert.delegatedCredential))) + z = z[4:] + copy(z, cert.delegatedCredential) + z = z[len(cert.delegatedCredential):] + extensionLen += 4 + len(cert.delegatedCredential) + } + + extLenPos[0] = uint8(extensionLen >> 8) + extLenPos[1] = uint8(extensionLen) + } + + m.raw = x + return +} + +func (m *certificateMsg13) unmarshal(data []byte) alert { + if len(data) < 5 { + return alertDecodeError + } + + m.raw = data + + ctxLen := data[4] + if len(data) < int(ctxLen)+5+3 { + return alertDecodeError + } + m.requestContext = data[5 : 5+ctxLen] + + d := data[5+ctxLen:] + certsLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) != certsLen+3 { + return alertDecodeError + } + + numCerts := 0 + d = d[3:] + for certsLen > 0 { + if len(d) < 4 { + return alertDecodeError + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return alertDecodeError + } + d = d[3+certLen:] + + if len(d) < 2 { + return alertDecodeError + } + extLen := uint16(d[0])<<8 | uint16(d[1]) + if uint16(len(d)) < 2+extLen { + return alertDecodeError + } + d = d[2+extLen:] + + certsLen -= 3 + certLen + 2 + uint32(extLen) + numCerts++ + } + + m.certificates = make([]certificateEntry, numCerts) + d = data[8+ctxLen:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i].data = d[3 : 3+certLen] + d = d[3+certLen:] + + extLen := uint16(d[0])<<8 | uint16(d[1]) + d = d[2:] + for extLen > 0 { + if extLen < 4 { + return alertDecodeError + } + typ := uint16(d[0])<<8 | uint16(d[1]) + bodyLen := uint16(d[2])<<8 | uint16(d[3]) + if extLen < 4+bodyLen { + return alertDecodeError + } + body := d[4 : 4+bodyLen] + d = d[4+bodyLen:] + extLen -= 4 + bodyLen + + switch typ { + case extensionStatusRequest: + if len(body) < 4 || body[0] != 0x01 { + return alertDecodeError + } + ocspLen := int(body[1])<<16 | int(body[2])<<8 | int(body[3]) + if len(body) != 4+ocspLen { + return alertDecodeError + } + m.certificates[i].ocspStaple = body[4:] + + case extensionSCT: + if len(body) < 2 { + return alertDecodeError + } + listLen := int(body[0])<<8 | int(body[1]) + body = body[2:] + if len(body) != listLen { + return alertDecodeError + } + for len(body) > 0 { + if len(body) < 2 { + return alertDecodeError + } + sctLen := int(body[0])<<8 | int(body[1]) + if len(body) < 2+sctLen { + return alertDecodeError + } + m.certificates[i].sctList = append(m.certificates[i].sctList, body[2:2+sctLen]) + body = body[2+sctLen:] + } + case extensionDelegatedCredential: + m.certificates[i].delegatedCredential = body + } + } + } + + return alertSuccess +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*serverKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.key, m1.key) +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) alert { + m.raw = data + if len(data) < 4 { + return alertDecodeError + } + m.key = data[4:] + return alertSuccess +} + +type certificateStatusMsg struct { + raw []byte + statusType uint8 + response []byte +} + +func (m *certificateStatusMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateStatusMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.statusType == m1.statusType && + bytes.Equal(m.response, m1.response) +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var x []byte + if m.statusType == statusTypeOCSP { + x = make([]byte, 4+4+len(m.response)) + x[0] = typeCertificateStatus + l := len(m.response) + 4 + x[1] = byte(l >> 16) + x[2] = byte(l >> 8) + x[3] = byte(l) + x[4] = statusTypeOCSP + + l -= 4 + x[5] = byte(l >> 16) + x[6] = byte(l >> 8) + x[7] = byte(l) + copy(x[8:], m.response) + } else { + x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} + } + + m.raw = x + return x +} + +func (m *certificateStatusMsg) unmarshal(data []byte) alert { + m.raw = data + if len(data) < 5 { + return alertDecodeError + } + m.statusType = data[4] + + m.response = nil + if m.statusType == statusTypeOCSP { + if len(data) < 8 { + return alertDecodeError + } + respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) + if uint32(len(data)) != 4+4+respLen { + return alertDecodeError + } + m.response = data[8:] + } + return alertSuccess +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) equal(i interface{}) bool { + _, ok := i.(*serverHelloDoneMsg) + return ok +} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) alert { + if len(data) != 4 { + return alertDecodeError + } + return alertSuccess +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*clientKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ciphertext, m1.ciphertext) +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) alert { + m.raw = data + if len(data) < 4 { + return alertDecodeError + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return alertDecodeError + } + m.ciphertext = data[4:] + return alertSuccess +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) equal(i interface{}) bool { + m1, ok := i.(*finishedMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.verifyData, m1.verifyData) +} + +func (m *finishedMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + x = make([]byte, 4+len(m.verifyData)) + x[0] = typeFinished + x[3] = byte(len(m.verifyData)) + copy(x[4:], m.verifyData) + m.raw = x + return +} + +func (m *finishedMsg) unmarshal(data []byte) alert { + m.raw = data + if len(data) < 4 { + return alertDecodeError + } + m.verifyData = data[4:] + return alertSuccess +} + +type nextProtoMsg struct { + raw []byte + proto string +} + +func (m *nextProtoMsg) equal(i interface{}) bool { + m1, ok := i.(*nextProtoMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.proto == m1.proto +} + +func (m *nextProtoMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + l := len(m.proto) + if l > 255 { + l = 255 + } + + padding := 32 - (l+2)%32 + length := l + padding + 2 + x := make([]byte, length+4) + x[0] = typeNextProtocol + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + y := x[4:] + y[0] = byte(l) + copy(y[1:], []byte(m.proto[0:l])) + y = y[1+l:] + y[0] = byte(padding) + + m.raw = x + + return x +} + +func (m *nextProtoMsg) unmarshal(data []byte) alert { + m.raw = data + + if len(data) < 5 { + return alertDecodeError + } + data = data[4:] + protoLen := int(data[0]) + data = data[1:] + if len(data) < protoLen { + return alertDecodeError + } + m.proto = string(data[0:protoLen]) + data = data[protoLen:] + + if len(data) < 1 { + return alertDecodeError + } + paddingLen := int(data[0]) + data = data[1:] + if len(data) != paddingLen { + return alertDecodeError + } + + return alertSuccess +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAndHash indicates whether this message includes a list + // of signature and hash functions. This change was introduced with TLS + // 1.2. + hasSignatureAndHash bool + + certificateTypes []byte + supportedSignatureAlgorithms []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.certificateTypes, m1.certificateTypes) && + eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && + eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc4346#section-7.4.4 + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAndHash { + length += 2 + 2*len(m.supportedSignatureAlgorithms) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAndHash { + n := len(m.supportedSignatureAlgorithms) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAlgo := range m.supportedSignatureAlgorithms { + y[0] = uint8(sigAlgo >> 8) + y[1] = uint8(sigAlgo) + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) alert { + m.raw = data + + if len(data) < 5 { + return alertDecodeError + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return alertDecodeError + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return alertDecodeError + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return alertDecodeError + } + + data = data[numCertTypes:] + + if m.hasSignatureAndHash { + if len(data) < 2 { + return alertDecodeError + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return alertDecodeError + } + if len(data) < int(sigAndHashLen) { + return alertDecodeError + } + numSigAlgos := sigAndHashLen / 2 + m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + } + + if len(data) < 2 { + return alertDecodeError + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return alertDecodeError + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return alertDecodeError + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return alertDecodeError + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + if len(data) != 0 { + return alertDecodeError + } + + return alertSuccess +} + +type certificateRequestMsg13 struct { + raw []byte + + requestContext []byte + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg13) equal(i interface{}) bool { + m1, ok := i.(*certificateRequestMsg13) + return ok && + bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.requestContext, m1.requestContext) && + eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && + eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) && + eqSignatureAlgorithms(m.supportedSignatureAlgorithmsCert, m1.supportedSignatureAlgorithmsCert) +} + +func (m *certificateRequestMsg13) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.3.2 + length := 1 + len(m.requestContext) + numExtensions := 1 + extensionsLength := 2 + 2*len(m.supportedSignatureAlgorithms) + + if m.getSignatureAlgorithmsCert() != nil { + numExtensions += 1 + extensionsLength += 2 + 2*len(m.getSignatureAlgorithmsCert()) + } + + casLength := 0 + if len(m.certificateAuthorities) > 0 { + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + extensionsLength += 2 + casLength + numExtensions++ + } + + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.requestContext)) + copy(x[5:], m.requestContext) + z := x[5+len(m.requestContext):] + + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) + z = z[2:] + + // TODO: this function should be reused by CH + z = marshalExtensionSignatureAlgorithms(extensionSignatureAlgorithms, z, m.supportedSignatureAlgorithms) + + if m.getSignatureAlgorithmsCert() != nil { + z = marshalExtensionSignatureAlgorithms(extensionSignatureAlgorithmsCert, z, m.getSignatureAlgorithmsCert()) + } + // certificate_authorities + if casLength > 0 { + z[0] = byte(extensionCAs >> 8) + z[1] = byte(extensionCAs) + l := 2 + casLength + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + z[0] = uint8(casLength >> 8) + z[1] = uint8(casLength) + z = z[2:] + for _, ca := range m.certificateAuthorities { + z[0] = uint8(len(ca) >> 8) + z[1] = uint8(len(ca)) + z = z[2:] + copy(z, ca) + z = z[len(ca):] + } + } + + m.raw = x + return +} + +func (m *certificateRequestMsg13) unmarshal(data []byte) alert { + m.raw = data + m.supportedSignatureAlgorithms = nil + m.certificateAuthorities = nil + + if len(data) < 5 { + return alertDecodeError + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return alertDecodeError + } + + ctxLen := data[4] + if len(data) < 5+int(ctxLen)+2 { + return alertDecodeError + } + m.requestContext = data[5 : 5+ctxLen] + data = data[5+ctxLen:] + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if len(data) != extensionsLength { + return alertDecodeError + } + + for len(data) != 0 { + if len(data) < 4 { + return alertDecodeError + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return alertDecodeError + } + + switch extension { + case extensionSignatureAlgorithms: + // TODO: unmarshalExtensionSignatureAlgorithms should be shared with CH and pre-1.3 CV + // https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.2.3 + var err alert + m.supportedSignatureAlgorithms, err = unmarshalExtensionSignatureAlgorithms(data, length) + if err != alertSuccess { + return err + } + case extensionSignatureAlgorithmsCert: + var err alert + m.supportedSignatureAlgorithmsCert, err = unmarshalExtensionSignatureAlgorithms(data, length) + if err != alertSuccess { + return err + } + case extensionCAs: + // TODO DRY: share code with CH + if length < 2 { + return alertDecodeError + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 || l < 3 { + return alertDecodeError + } + cas := make([]byte, l) + copy(cas, data[2:]) + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return alertDecodeError + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return alertDecodeError + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + } + data = data[length:] + } + + if len(m.supportedSignatureAlgorithms) == 0 { + return alertDecodeError + } + return alertSuccess +} + +func (m *certificateRequestMsg13) getSignatureAlgorithmsCert() []SignatureScheme { + return signAlgosCertList(m.supportedSignatureAlgorithms, m.supportedSignatureAlgorithmsCert) +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAndHash bool + signatureAlgorithm SignatureScheme + signature []byte +} + +func (m *certificateVerifyMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateVerifyMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.hasSignatureAndHash == m1.hasSignatureAndHash && + m.signatureAlgorithm == m1.signatureAlgorithm && + bytes.Equal(m.signature, m1.signature) +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc4346#section-7.4.8 + siglength := len(m.signature) + length := 2 + siglength + if m.hasSignatureAndHash { + length += 2 + } + x = make([]byte, 4+length) + x[0] = typeCertificateVerify + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + y := x[4:] + if m.hasSignatureAndHash { + y[0] = uint8(m.signatureAlgorithm >> 8) + y[1] = uint8(m.signatureAlgorithm) + y = y[2:] + } + y[0] = uint8(siglength >> 8) + y[1] = uint8(siglength) + copy(y[2:], m.signature) + + m.raw = x + + return +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) alert { + m.raw = data + + if len(data) < 6 { + return alertDecodeError + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return alertDecodeError + } + + data = data[4:] + if m.hasSignatureAndHash { + m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + + if len(data) < 2 { + return alertDecodeError + } + siglength := int(data[0])<<8 + int(data[1]) + data = data[2:] + if len(data) != siglength { + return alertDecodeError + } + + m.signature = data + + return alertSuccess +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) equal(i interface{}) bool { + m1, ok := i.(*newSessionTicketMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ticket, m1.ticket) +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc5077#section-3.3 + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) alert { + m.raw = data + + if len(data) < 10 { + return alertDecodeError + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return alertDecodeError + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return alertDecodeError + } + + m.ticket = data[10:] + + return alertSuccess +} + +type newSessionTicketMsg13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + ticket []byte + withEarlyDataInfo bool + maxEarlyDataLength uint32 +} + +func (m *newSessionTicketMsg13) equal(i interface{}) bool { + m1, ok := i.(*newSessionTicketMsg13) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.lifetime == m1.lifetime && + m.ageAdd == m1.ageAdd && + bytes.Equal(m.nonce, m1.nonce) && + bytes.Equal(m.ticket, m1.ticket) && + m.withEarlyDataInfo == m1.withEarlyDataInfo && + m.maxEarlyDataLength == m1.maxEarlyDataLength +} + +func (m *newSessionTicketMsg13) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.6.1 + nonceLen := len(m.nonce) + ticketLen := len(m.ticket) + length := 13 + nonceLen + ticketLen + if m.withEarlyDataInfo { + length += 8 + } + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(m.lifetime >> 24) + x[5] = uint8(m.lifetime >> 16) + x[6] = uint8(m.lifetime >> 8) + x[7] = uint8(m.lifetime) + x[8] = uint8(m.ageAdd >> 24) + x[9] = uint8(m.ageAdd >> 16) + x[10] = uint8(m.ageAdd >> 8) + x[11] = uint8(m.ageAdd) + + x[12] = uint8(nonceLen) + copy(x[13:13+nonceLen], m.nonce) + + y := x[13+nonceLen:] + y[0] = uint8(ticketLen >> 8) + y[1] = uint8(ticketLen) + copy(y[2:2+ticketLen], m.ticket) + + if m.withEarlyDataInfo { + z := y[2+ticketLen:] + // z[0] is already 0, this is the extensions vector length. + z[1] = 8 + z[2] = uint8(extensionEarlyData >> 8) + z[3] = uint8(extensionEarlyData) + z[5] = 4 + z[6] = uint8(m.maxEarlyDataLength >> 24) + z[7] = uint8(m.maxEarlyDataLength >> 16) + z[8] = uint8(m.maxEarlyDataLength >> 8) + z[9] = uint8(m.maxEarlyDataLength) + } + + m.raw = x + + return +} + +func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { + m.raw = data + m.maxEarlyDataLength = 0 + m.withEarlyDataInfo = false + + if len(data) < 17 { + return alertDecodeError + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return alertDecodeError + } + + m.lifetime = uint32(data[4])<<24 | uint32(data[5])<<16 | + uint32(data[6])<<8 | uint32(data[7]) + m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 | + uint32(data[10])<<8 | uint32(data[11]) + + nonceLen := int(data[12]) + if nonceLen == 0 || 13+nonceLen+2 > len(data) { + return alertDecodeError + } + m.nonce = data[13 : 13+nonceLen] + + data = data[13+nonceLen:] + ticketLen := int(data[0])<<8 + int(data[1]) + if ticketLen == 0 || 2+ticketLen+2 > len(data) { + return alertDecodeError + } + m.ticket = data[2 : 2+ticketLen] + + data = data[2+ticketLen:] + extLen := int(data[0])<<8 + int(data[1]) + if extLen != len(data)-2 { + return alertDecodeError + } + + data = data[2:] + for len(data) > 0 { + if len(data) < 4 { + return alertDecodeError + } + extType := uint16(data[0])<<8 + uint16(data[1]) + length := int(data[2])<<8 + int(data[3]) + data = data[4:] + + switch extType { + case extensionEarlyData: + if length != 4 { + return alertDecodeError + } + m.withEarlyDataInfo = true + m.maxEarlyDataLength = uint32(data[0])<<24 | uint32(data[1])<<16 | + uint32(data[2])<<8 | uint32(data[3]) + } + data = data[length:] + } + + return alertSuccess +} + +type endOfEarlyDataMsg struct { +} + +func (*endOfEarlyDataMsg) marshal() []byte { + return []byte{typeEndOfEarlyData, 0, 0, 0} +} + +func (*endOfEarlyDataMsg) unmarshal(data []byte) alert { + if len(data) != 4 { + return alertDecodeError + } + return alertSuccess +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) alert { + if len(data) != 4 { + return alertDecodeError + } + return alertSuccess +} + +func eqUint16s(x, y []uint16) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqCurveIDs(x, y []CurveID) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqStrings(x, y []string) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqByteSlices(x, y [][]byte) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if !bytes.Equal(v, y[i]) { + return false + } + } + return true +} + +func eqSignatureAlgorithms(x, y []SignatureScheme) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if v != y[i] { + return false + } + } + return true +} + +func eqKeyShares(x, y []keyShare) bool { + if len(x) != len(y) { + return false + } + for i := range x { + if x[i].group != y[i].group { + return false + } + if !bytes.Equal(x[i].data, y[i].data) { + return false + } + } + return true +} + +func findExtension(data []byte, extensionType uint16) []byte { + for len(data) != 0 { + if len(data) < 4 { + return nil + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return nil + } + if extension == extensionType { + return data[:length] + } + data = data[length:] + } + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls/handshake_server.go b/vendor/github.com/marten-seemann/qtls/handshake_server.go new file mode 100644 index 00000000..5be91f1b --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/handshake_server.go @@ -0,0 +1,943 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "io" + "sync/atomic" +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + suite *cipherSuite + masterSecret []byte + cachedClientHelloInfo *ClientHelloInfo + clientHello *clientHelloMsg + hello *serverHelloMsg + cert *Certificate + privateKey crypto.PrivateKey + + // A marshalled DelegatedCredential to be sent to the client in the + // handshake. + delegatedCredential []byte + + // TLS 1.0-1.2 fields + ellipticOk bool + ecdsaOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + certsFromClient [][]byte + + // TLS 1.3 fields + hello13Enc *encryptedExtensionsMsg + keySchedule *keySchedule13 + clientFinishedKey []byte + hsClientTrafficSecret []byte + appClientTrafficSecret []byte +} + +// serverHandshake performs a TLS handshake as a server. +// c.out.Mutex <= L; c.handshakeMutex <= L. +func (c *Conn) serverHandshake() error { + // If this is the first server handshake, we generate a random key to + // encrypt the tickets with. + c.config.serverInitOnce.Do(func() { c.config.serverInit(nil) }) + c.setAlternativeRecordLayer() + + hs := serverHandshakeState{ + c: c, + } + c.in.traceErr = hs.traceErr + c.out.traceErr = hs.traceErr + isResume, err := hs.readClientHello() + if err != nil { + return err + } + + // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 + // and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2 + c.buffering = true + if c.vers >= VersionTLS13 { + if err := hs.doTLS13Handshake(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.hs = &hs + // If the client is sending early data while the server expects + // it, delay the Finished check until HandshakeConfirmed() is + // called or until all early data is Read(). Otherwise, complete + // authenticating the client now (there is no support for + // sending 0.5-RTT data to a potential unauthenticated client). + if c.phase != readingEarlyData { + if err := hs.readClientFinished13(false); err != nil { + return err + } + } + c.handshakeComplete = true + return nil + } else if isResume { + // The client has included a session ticket and so we do an abbreviated handshake. + if err := hs.doResumeHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if hs.hello.ticketSupported { + if err := hs.sendSessionTicket(); err != nil { + return err + } + } + if err := hs.sendFinished(c.serverFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = false + if err := hs.readFinished(nil); err != nil { + return err + } + c.didResume = true + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readFinished(c.clientFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = true + c.buffering = true + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(nil); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } + if c.hand.Len() > 0 { + return c.sendAlert(alertUnexpectedMessage) + } + c.phase = handshakeConfirmed + atomic.StoreInt32(&c.handshakeConfirmed, 1) + c.handshakeComplete = true + + return nil +} + +// readClientHello reads a ClientHello message from the client and decides +// whether we will perform session resumption. +func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return false, err + } + var ok bool + hs.clientHello, ok = msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return false, unexpectedMessageError(hs.clientHello, msg) + } + + if c.config.GetConfigForClient != nil { + if newConfig, err := c.config.GetConfigForClient(hs.clientHelloInfo()); err != nil { + c.out.traceErr, c.in.traceErr = nil, nil // disable tracing + c.sendAlert(alertInternalError) + return false, err + } else if newConfig != nil { + newConfig.serverInitOnce.Do(func() { newConfig.serverInit(c.config) }) + c.config = newConfig + } + } + + var keyShares []CurveID + for _, ks := range hs.clientHello.keyShares { + keyShares = append(keyShares, ks.group) + } + + if hs.clientHello.supportedVersions != nil { + c.vers, ok = c.config.pickVersion(hs.clientHello.supportedVersions) + if !ok { + c.sendAlert(alertProtocolVersion) + return false, fmt.Errorf("tls: none of the client versions (%x) are supported", hs.clientHello.supportedVersions) + } + } else { + c.vers, ok = c.config.mutualVersion(hs.clientHello.vers) + if !ok { + c.sendAlert(alertProtocolVersion) + return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) + } + } + c.haveVers = true + + preferredCurves := c.config.curvePreferences() +Curves: + for _, curve := range hs.clientHello.supportedCurves { + for _, supported := range preferredCurves { + if supported == curve { + hs.ellipticOk = true + break Curves + } + } + } + + // If present, the supported points extension must include uncompressed. + // Can be absent. This behavior mirrors BoringSSL. + if hs.clientHello.supportedPoints != nil { + supportedPointFormat := false + for _, pointFormat := range hs.clientHello.supportedPoints { + if pointFormat == pointFormatUncompressed { + supportedPointFormat = true + break + } + } + if !supportedPointFormat { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: client does not support uncompressed points") + } + } + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertIllegalParameter) + return false, errors.New("tls: client does not support uncompressed connections") + } + if len(hs.clientHello.compressionMethods) != 1 && c.vers >= VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return false, errors.New("tls: 1.3 client offered compression") + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + if c.vers < VersionTLS13 { + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + hs.hello.random = make([]byte, 32) + _, err = io.ReadFull(c.config.rand(), hs.hello.random) + if err != nil { + c.sendAlert(alertInternalError) + return false, err + } + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + } else { + if hs.c.config.ReceivedExtensions != nil { + if err := hs.c.config.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions); err != nil { + c.sendAlert(alertInternalError) + return false, err + } + } + hs.hello = new(serverHelloMsg) + hs.hello13Enc = new(encryptedExtensionsMsg) + if hs.c.config.GetExtensions != nil { + hs.hello13Enc.additionalExtensions = hs.c.config.GetExtensions(typeEncryptedExtensions) + } + hs.hello.vers = c.vers + hs.hello.random = make([]byte, 32) + hs.hello.sessionId = hs.clientHello.sessionId + _, err = io.ReadFull(c.config.rand(), hs.hello.random) + if err != nil { + c.sendAlert(alertInternalError) + return false, err + } + } + + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { + if hs.hello13Enc != nil { + hs.hello13Enc.alpnProtocol = selectedProto + } else { + hs.hello.alpnProtocol = selectedProto + } + c.clientProtocol = selectedProto + } + } else { + // Although sending an empty NPN extension is reasonable, Firefox has + // had a bug around this. Best to send nothing at all if + // c.config.NextProtos is empty. See + // https://golang.org/issue/5445. + if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 && c.vers < VersionTLS13 { + hs.hello.nextProtoNeg = true + hs.hello.nextProtos = c.config.NextProtos + } + } + + hs.cert, err = c.config.getCertificate(hs.clientHelloInfo()) + if err != nil { + c.sendAlert(alertInternalError) + return false, err + } + + // Set the private key for this handshake to the certificate's secret key. + hs.privateKey = hs.cert.PrivateKey + + if hs.clientHello.scts { + hs.hello.scts = hs.cert.SignedCertificateTimestamps + } + + // Set the private key to the DC private key if the client and server are + // willing to negotiate the delegated credential extension. + // + // Check to see if a DelegatedCredential is available and should be used. + // If one is available, the session is using TLS >= 1.2, and the client + // accepts the delegated credential extension, then set the handshake + // private key to the DC private key. + if c.config.GetDelegatedCredential != nil && hs.clientHello.delegatedCredential && c.vers >= VersionTLS12 { + dc, sk, err := c.config.GetDelegatedCredential(hs.clientHelloInfo(), c.vers) + if err != nil { + c.sendAlert(alertInternalError) + return false, err + } + + // Set the handshake private key. + if dc != nil { + hs.privateKey = sk + hs.delegatedCredential = dc + } + } + + if priv, ok := hs.privateKey.(crypto.Signer); ok { + switch priv.Public().(type) { + case *ecdsa.PublicKey: + hs.ecdsaOk = true + case *rsa.PublicKey: + hs.rsaSignOk = true + default: + c.sendAlert(alertInternalError) + return false, fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) + } + } + if priv, ok := hs.privateKey.(crypto.Decrypter); ok { + switch priv.Public().(type) { + case *rsa.PublicKey: + hs.rsaDecryptOk = true + default: + c.sendAlert(alertInternalError) + return false, fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) + } + } + + if c.vers != VersionTLS13 && hs.checkForResumption() { + return true, nil + } + + var preferenceList, supportedList []uint16 + if c.config.PreferServerCipherSuites { + preferenceList = c.config.cipherSuites() + supportedList = hs.clientHello.cipherSuites + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = c.config.cipherSuites() + } + + for _, id := range preferenceList { + if hs.setCipherSuite(id, supportedList, c.vers) { + break + } + } + + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: no cipher suite supported by both client and server") + } + + // See https://tools.ietf.org/html/rfc7507. + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. + if c.vers < c.config.maxVersion() { + c.sendAlert(alertInappropriateFallback) + return false, errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return false, nil +} + +// checkForResumption reports whether we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if c.config.SessionTicketsDisabled { + return false + } + + sessionTicket := append([]uint8{}, hs.clientHello.sessionTicket...) + serializedState, usedOldKey := c.decryptTicket(sessionTicket) + hs.sessionState = &sessionState{usedOldKey: usedOldKey} + if hs.sessionState.unmarshal(serializedState) != alertSuccess { + return false + } + + // Never resume a session for a different TLS version. + if c.vers != hs.sessionState.vers { + return false + } + + // Do not resume connections where client support for EMS has changed + if (hs.clientHello.extendedMSSupported && c.config.UseExtendedMasterSecret) != hs.sessionState.usedEMS { + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return false + } + + // Check that we also support the ciphersuite from the session. + if !hs.setCipherSuite(hs.sessionState.cipherSuite, c.config.cipherSuites(), hs.sessionState.vers) { + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := c.config.ClientAuth == RequireAnyClientCert || c.config.ClientAuth == RequireAndVerifyClientCert + if needClientCerts && !sessionHasClientCerts { + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + return false + } + + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + hs.hello.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = hs.sessionState.usedOldKey + hs.hello.extendedMSSupported = hs.clientHello.extendedMSSupported && c.config.UseExtendedMasterSecret + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if len(hs.sessionState.certificates) > 0 { + if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil { + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + c.useEMS = hs.sessionState.usedEMS + + return nil +} + +func (hs *serverHandshakeState) doFullHandshake() error { + c := hs.c + + if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled + hs.hello.cipherSuite = hs.suite.id + hs.hello.extendedMSSupported = hs.clientHello.extendedMSSupported && c.config.UseExtendedMasterSecret + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + if c.config.ClientAuth == NoClientCert { + // No need to keep a full record of the handshake if client + // certificates won't be used. + hs.finishedHash.discardHandshakeBuffer() + } + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.statusType = statusTypeOCSP + certStatus.response = hs.cert.OCSPStaple + hs.finishedHash.Write(certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } + } + + keyAgreement := hs.suite.ka(c.vers) + skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.privateKey, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil { + hs.finishedHash.Write(skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } + } + + if c.config.ClientAuth >= RequestClientCert { + // Request a client certificate + certReq := new(certificateRequestMsg) + certReq.certificateTypes = []byte{ + byte(certTypeRSASign), + byte(certTypeECDSASign), + } + if c.vers >= VersionTLS12 { + certReq.hasSignatureAndHash = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + hs.finishedHash.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + helloDone := new(serverHelloDoneMsg) + hs.finishedHash.Write(helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } + + if _, err := c.flush(); err != nil { + return err + } + + var pub crypto.PublicKey // public key for client auth, if any + + msg, err := c.readHandshake() + if err != nil { + return err + } + + var ok bool + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + if c.config.ClientAuth >= RequestClientCert { + if certMsg, ok = msg.(*certificateMsg); !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if len(certMsg.certificates) == 0 { + // The client didn't actually send a certificate + switch c.config.ClientAuth { + case RequireAnyClientCert, RequireAndVerifyClientCert: + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + } + + pub, err = hs.processCertsFromClient(certMsg.certificates) + if err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.finishedHash.Write(ckx.marshal()) + + preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.privateKey, ckx, c.vers) + if err != nil { + if err == errClientKeyExchange { + c.sendAlert(alertDecodeError) + } else { + c.sendAlert(alertInternalError) + } + return err + } + c.useEMS = hs.hello.extendedMSSupported + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random, hs.finishedHash, c.useEMS) + if err := c.config.writeKeyLog("CLIENT_RANDOM", hs.clientHello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // Determine the signature type. + _, sigType, hashFunc, err := pickSignatureAlgorithm(pub, []SignatureScheme{certVerify.signatureAlgorithm}, supportedSignatureAlgorithms, c.vers) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + + var digest []byte + if digest, err = hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret); err == nil { + err = verifyHandshakeSignature(sigType, pub, hashFunc, digest, certVerify.signature) + } + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: could not validate signature of connection nonces: " + err.Error()) + } + + hs.finishedHash.Write(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(c.vers, clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(c.vers, serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + + return nil +} + +func (hs *serverHandshakeState) readFinished(out []byte) error { + c := hs.c + + c.readRecord(recordTypeChangeCipherSpec) + if c.in.err != nil { + return c.in.err + } + + if hs.hello.nextProtoNeg { + msg, err := c.readHandshake() + if err != nil { + return err + } + nextProto, ok := msg.(*nextProtoMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(nextProto, msg) + } + hs.finishedHash.Write(nextProto.marshal()) + c.clientProtocol = nextProto.proto + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertDecryptError) + return errors.New("tls: client's Finished message is incorrect") + } + + hs.finishedHash.Write(clientFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + m := new(newSessionTicketMsg) + + var err error + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + certificates: hs.certsFromClient, + usedEMS: c.useEMS, + } + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + c.cipherSuite = hs.suite.id + copy(out, finished.verifyData) + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (crypto.PublicKey, error) { + c := hs.c + + hs.certsFromClient = certificates + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return nil, errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return nil, errors.New("tls: failed to verify client's certificate: " + err.Error()) + } + + c.verifiedChains = chains + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return nil, err + } + } + + if len(certs) == 0 { + return nil, nil + } + + var pub crypto.PublicKey + switch key := certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey: + pub = key + default: + c.sendAlert(alertUnsupportedCertificate) + return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + c.peerCertificates = certs + return pub, nil +} + +// setCipherSuite sets a cipherSuite with the given id as the serverHandshakeState +// suite if that cipher suite is acceptable to use. +// It returns a bool indicating if the suite was set. +func (hs *serverHandshakeState) setCipherSuite(id uint16, supportedCipherSuites []uint16, version uint16) bool { + for _, supported := range supportedCipherSuites { + if id == supported { + var candidate *cipherSuite + + for _, s := range cipherSuites { + if s.id == id { + candidate = s + break + } + } + if candidate == nil { + continue + } + + if version >= VersionTLS13 && candidate.flags&suiteTLS13 != 0 { + hs.suite = candidate + return true + } + if version < VersionTLS13 && candidate.flags&suiteTLS13 != 0 { + continue + } + + // Don't select a ciphersuite which we can't + // support for this client. + if candidate.flags&suiteECDHE != 0 { + if !hs.ellipticOk { + continue + } + if candidate.flags&suiteECDSA != 0 { + if !hs.ecdsaOk { + continue + } + } else if !hs.rsaSignOk { + continue + } + } else if !hs.rsaDecryptOk { + continue + } + if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 { + continue + } + hs.suite = candidate + return true + } + } + return false +} + +// suppVersArray is the backing array of ClientHelloInfo.SupportedVersions +var suppVersArray = [...]uint16{VersionTLS12, VersionTLS11, VersionTLS10, VersionSSL30} + +func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo { + if hs.cachedClientHelloInfo != nil { + return hs.cachedClientHelloInfo + } + + var supportedVersions []uint16 + if hs.clientHello.supportedVersions != nil { + supportedVersions = hs.clientHello.supportedVersions + } else if hs.clientHello.vers > VersionTLS12 { + supportedVersions = suppVersArray[:] + } else if hs.clientHello.vers >= VersionSSL30 { + supportedVersions = suppVersArray[VersionTLS12-hs.clientHello.vers:] + } + + var pskBinder []byte + if len(hs.clientHello.psks) > 0 { + pskBinder = hs.clientHello.psks[0].binder + } + + hs.cachedClientHelloInfo = &ClientHelloInfo{ + CipherSuites: hs.clientHello.cipherSuites, + ServerName: hs.clientHello.serverName, + SupportedCurves: hs.clientHello.supportedCurves, + SupportedPoints: hs.clientHello.supportedPoints, + SignatureSchemes: hs.clientHello.supportedSignatureAlgorithms, + SupportedProtos: hs.clientHello.alpnProtocols, + SupportedVersions: supportedVersions, + Conn: hs.c.conn, + Offered0RTTData: hs.clientHello.earlyData, + AcceptsDelegatedCredential: hs.clientHello.delegatedCredential, + Fingerprint: pskBinder, + } + + return hs.cachedClientHelloInfo +} diff --git a/vendor/github.com/marten-seemann/qtls/hkdf.go b/vendor/github.com/marten-seemann/qtls/hkdf.go new file mode 100644 index 00000000..5503b595 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/hkdf.go @@ -0,0 +1,58 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +// Mostly derived from golang.org/x/crypto/hkdf, but with an exposed +// Extract API. +// +// HKDF is a cryptographic key derivation function (KDF) with the goal of +// expanding limited input keying material into one or more cryptographically +// strong secret keys. +// +// RFC 5869: https://tools.ietf.org/html/rfc5869 + +import ( + "crypto" + "crypto/hmac" +) + +func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte { + var ( + expander = hmac.New(hash.New, prk) + res = make([]byte, l) + counter = byte(1) + prev []byte + ) + + if l > 255*expander.Size() { + panic("hkdf: requested too much output") + } + + p := res + for len(p) > 0 { + expander.Reset() + expander.Write(prev) + expander.Write(info) + expander.Write([]byte{counter}) + prev = expander.Sum(prev[:0]) + counter++ + n := copy(p, prev) + p = p[n:] + } + + return res +} + +func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte { + if salt == nil { + salt = make([]byte, hash.Size()) + } + if secret == nil { + secret = make([]byte, hash.Size()) + } + extractor := hmac.New(hash.New, salt) + extractor.Write(secret) + return extractor.Sum(nil) +} diff --git a/vendor/github.com/marten-seemann/qtls/key_agreement.go b/vendor/github.com/marten-seemann/qtls/key_agreement.go new file mode 100644 index 00000000..6bdbbd94 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/key_agreement.go @@ -0,0 +1,402 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/elliptic" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "errors" + "io" + "math/big" + + "golang.org/x/crypto/curve25519" +) + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct{} + +func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, sk crypto.PrivateKey, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, sk crypto.PrivateKey, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + + ciphertext := ckx.ciphertext + if version != VersionSSL30 { + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext = ckx.ciphertext[2:] + } + priv, ok := sk.(crypto.Decrypter) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") + } + // Perform constant time RSA PKCS#1 v1.5 decryption + preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) + if err != nil { + return nil, err + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + return preMasterSecret, nil +} + +func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, pk crypto.PublicKey, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, pk crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + preMasterSecret[0] = byte(clientHello.vers >> 8) + preMasterSecret[1] = byte(clientHello.vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), pk.(*rsa.PublicKey), preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// using the given hash function. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) ([]byte, error) { + if version >= VersionTLS12 { + h := hashFunc.New() + for _, slice := range slices { + h.Write(slice) + } + digest := h.Sum(nil) + return digest, nil + } + if sigType == signatureECDSA { + return sha1Hash(slices), nil + } + return md5SHA1Hash(slices), nil +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } + +} + +// ecdheKeyAgreement implements a TLS key agreement where the server +// generates an ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// either be ECDSA or RSA. +type ecdheKeyAgreement struct { + version uint16 + isRSA bool + privateKey []byte + curveid CurveID + + // publicKey is used to store the peer's public value when X25519 is + // being used. + publicKey []byte + // x and y are used to store the peer's public value when one of the + // NIST curves is being used. + x, y *big.Int +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, sk crypto.PrivateKey, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + preferredCurves := config.curvePreferences() + +NextCandidate: + for _, candidate := range preferredCurves { + for _, c := range clientHello.supportedCurves { + if candidate == c { + ka.curveid = c + break NextCandidate + } + } + } + + if ka.curveid == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + + var ecdhePublic []byte + + if ka.curveid == X25519 { + var scalar, public [32]byte + if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil { + return nil, err + } + + curve25519.ScalarBaseMult(&public, &scalar) + ka.privateKey = scalar[:] + ecdhePublic = public[:] + } else { + curve, ok := curveForCurveID(ka.curveid) + if !ok { + return nil, errors.New("tls: preferredCurves includes unsupported curve") + } + + var x, y *big.Int + var err error + ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand()) + if err != nil { + return nil, err + } + ecdhePublic = elliptic.Marshal(curve, x, y) + } + + // http://tools.ietf.org/html/rfc4492#section-5.4 + serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHParams[0] = 3 // named curve + serverECDHParams[1] = byte(ka.curveid >> 8) + serverECDHParams[2] = byte(ka.curveid) + serverECDHParams[3] = byte(len(ecdhePublic)) + copy(serverECDHParams[4:], ecdhePublic) + + priv, ok := sk.(crypto.Signer) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Signer") + } + + signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(priv.Public(), clientHello.supportedSignatureAlgorithms, supportedSignatureAlgorithms, ka.version) + if err != nil { + return nil, err + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, hello.random, serverECDHParams) + if err != nil { + return nil, err + } + + var sig []byte + signOpts := crypto.SignerOpts(hashFunc) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc} + } + sig, err = priv.Sign(config.rand(), digest, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } + + skx := new(serverKeyExchangeMsg) + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig)) + copy(skx.key, serverECDHParams) + k := skx.key[len(serverECDHParams):] + if ka.version >= VersionTLS12 { + k[0] = byte(signatureAlgorithm >> 8) + k[1] = byte(signatureAlgorithm) + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + + return skx, nil +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, sk crypto.PrivateKey, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + + if ka.curveid == X25519 { + if len(ckx.ciphertext) != 1+32 { + return nil, errClientKeyExchange + } + + var theirPublic, sharedKey, scalar [32]byte + copy(theirPublic[:], ckx.ciphertext[1:]) + copy(scalar[:], ka.privateKey) + curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) + return sharedKey[:], nil + } + + curve, ok := curveForCurveID(ka.curveid) + if !ok { + panic("internal error") + } + x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:]) // Unmarshal also checks whether the given point is on the curve + if x == nil { + return nil, errClientKeyExchange + } + x, _ = curve.ScalarMult(x, y, ka.privateKey) + curveSize := (curve.Params().BitSize + 7) >> 3 + xBytes := x.Bytes() + if len(xBytes) == curveSize { + return xBytes, nil + } + preMasterSecret := make([]byte, curveSize) + copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes) + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, pk crypto.PublicKey, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + serverECDHParams := skx.key[:4+publicLen] + publicKey := serverECDHParams[4:] + + sig := skx.key[4+publicLen:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if ka.curveid == X25519 { + if len(publicKey) != 32 { + return errors.New("tls: bad X25519 public value") + } + ka.publicKey = publicKey + } else { + curve, ok := curveForCurveID(ka.curveid) + if !ok { + return errors.New("tls: server selected unsupported curve") + } + ka.x, ka.y = elliptic.Unmarshal(curve, publicKey) // Unmarshal also checks whether the given point is on the curve + if ka.x == nil { + return errServerKeyExchange + } + } + + var signatureAlgorithm SignatureScheme + if ka.version >= VersionTLS12 { + // handle SignatureAndHashAlgorithm + signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + } + _, sigType, hashFunc, err := pickSignatureAlgorithm(pk, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version) + if err != nil { + return err + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, serverHello.random, serverECDHParams) + if err != nil { + return err + } + return verifyHandshakeSignature(sigType, pk, hashFunc, digest, sig) +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, pk crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error) { + if ka.curveid == 0 { + return nil, nil, errors.New("tls: missing ServerKeyExchange message") + } + + var serialized, preMasterSecret []byte + + if ka.curveid == X25519 { + var ourPublic, theirPublic, sharedKey, scalar [32]byte + + if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil { + return nil, nil, err + } + + copy(theirPublic[:], ka.publicKey) + curve25519.ScalarBaseMult(&ourPublic, &scalar) + curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) + serialized = ourPublic[:] + preMasterSecret = sharedKey[:] + } else { + curve, ok := curveForCurveID(ka.curveid) + if !ok { + panic("internal error") + } + priv, mx, my, err := elliptic.GenerateKey(curve, config.rand()) + if err != nil { + return nil, nil, err + } + x, _ := curve.ScalarMult(ka.x, ka.y, priv) + preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3) + xBytes := x.Bytes() + copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes) + + serialized = elliptic.Marshal(curve, mx, my) + } + + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, 1+len(serialized)) + ckx.ciphertext[0] = byte(len(serialized)) + copy(ckx.ciphertext[1:], serialized) + + return preMasterSecret, ckx, nil +} diff --git a/vendor/github.com/marten-seemann/qtls/prf.go b/vendor/github.com/marten-seemann/qtls/prf.go new file mode 100644 index 00000000..1a6d3156 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/prf.go @@ -0,0 +1,355 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +// prf30 implements the SSL 3.0 pseudo-random function, as defined in +// www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 6. +func prf30(result, secret, label, seed []byte) { + hashSHA1 := sha1.New() + hashMD5 := md5.New() + + done := 0 + i := 0 + // RFC 5246 section 6.3 says that the largest PRF output needed is 128 + // bytes. Since no more ciphersuites will be added to SSLv3, this will + // remain true. Each iteration gives us 16 bytes so 10 iterations will + // be sufficient. + var b [11]byte + for done < len(result) { + for j := 0; j <= i; j++ { + b[j] = 'A' + byte(i) + } + + hashSHA1.Reset() + hashSHA1.Write(b[:i+1]) + hashSHA1.Write(secret) + hashSHA1.Write(seed) + digest := hashSHA1.Sum(nil) + + hashMD5.Reset() + hashMD5.Write(secret) + hashMD5.Write(digest) + + done += copy(result[done:], hashMD5.Sum(nil)) + i++ + } +} + +const ( + tlsRandomLength = 32 // Length of a random nonce in TLS 1.1. + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") +var extendedMasterSecretLabel = []byte("extended master secret") + +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + switch version { + case VersionSSL30: + return prf30, crypto.Hash(0) + case VersionTLS10, VersionTLS11: + return prf10, crypto.Hash(0) + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + default: + panic("unknown version") + } +} + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See http://tools.ietf.org/html/rfc5246#section-8.1 +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte, fin finishedHash, ems bool) []byte { + if ems { + session_hash := fin.Sum() + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, extendedMasterSecretLabel, session_hash) + return masterSecret + } else { + seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) + return masterSecret + } +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) + seed = append(seed, serverRandom...) + seed = append(seed, clientRandom...) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +// lookupTLSHash looks up the corresponding crypto.Hash for a given +// hash from a TLS SignatureScheme. +func lookupTLSHash(signatureAlgorithm SignatureScheme) (crypto.Hash, error) { + switch signatureAlgorithm { + case PKCS1WithSHA1, ECDSAWithSHA1: + return crypto.SHA1, nil + case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: + return crypto.SHA256, nil + case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: + return crypto.SHA384, nil + case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: + return crypto.SHA512, nil + default: + return 0, fmt.Errorf("tls: unsupported signature algorithm: %#04x", signatureAlgorithm) + } +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version == VersionSSL30 || version >= VersionTLS12 { + buffer = []byte{} + } + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// finishedSum30 calculates the contents of the verify_data member of a SSLv3 +// Finished message given the MD5 and SHA1 hashes of a set of handshake +// messages. +func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic []byte) []byte { + md5.Write(magic) + md5.Write(masterSecret) + md5.Write(ssl30Pad1[:]) + md5Digest := md5.Sum(nil) + + md5.Reset() + md5.Write(masterSecret) + md5.Write(ssl30Pad2[:]) + md5.Write(md5Digest) + md5Digest = md5.Sum(nil) + + sha1.Write(magic) + sha1.Write(masterSecret) + sha1.Write(ssl30Pad1[:40]) + sha1Digest := sha1.Sum(nil) + + sha1.Reset() + sha1.Write(masterSecret) + sha1.Write(ssl30Pad2[:40]) + sha1.Write(sha1Digest) + sha1Digest = sha1.Sum(nil) + + ret := make([]byte, len(md5Digest)+len(sha1Digest)) + copy(ret, md5Digest) + copy(ret[len(md5Digest):], sha1Digest) + return ret +} + +var ssl3ClientFinishedMagic = [4]byte{0x43, 0x4c, 0x4e, 0x54} +var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + if h.version == VersionSSL30 { + return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:]) + } + + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + if h.version == VersionSSL30 { + return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:]) + } + + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// hashForClientCertificate returns a digest over the handshake messages so far, +// suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) ([]byte, error) { + if (h.version == VersionSSL30 || h.version >= VersionTLS12) && h.buffer == nil { + panic("a handshake hash for a client-certificate was requested after discarding the handshake buffer") + } + + if h.version == VersionSSL30 { + if sigType != signaturePKCS1v15 { + return nil, errors.New("tls: unsupported signature type for client certificate") + } + + md5Hash := md5.New() + md5Hash.Write(h.buffer) + sha1Hash := sha1.New() + sha1Hash.Write(h.buffer) + return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), nil + } + if h.version >= VersionTLS12 { + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil), nil + } + + if sigType == signatureECDSA { + return h.server.Sum(nil), nil + } + + return h.Sum(), nil +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} diff --git a/vendor/github.com/marten-seemann/qtls/subcerts.go b/vendor/github.com/marten-seemann/qtls/subcerts.go new file mode 100644 index 00000000..e1cfaf43 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/subcerts.go @@ -0,0 +1,392 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +// Delegated credentials for TLS +// (https://tools.ietf.org/html/draft-ietf-tls-subcerts-02) is an IETF Internet +// draft and proposed TLS extension. This allows a backend server to delegate +// TLS termination to a trusted frontend. If the client supports this extension, +// then the frontend may use a "delegated credential" as the signing key in the +// handshake. A delegated credential is a short lived key pair delegated to the +// server by an entity trusted by the client. Once issued, credentials can't be +// revoked; in order to mitigate risk in case the frontend is compromised, the +// credential is only valid for a short time (days, hours, or even minutes). +// +// This implements draft 02. This draft doesn't specify an object identifier for +// the X.509 extension; we use one assigned by Cloudflare. In addition, IANA has +// not assigned an extension ID for this extension; we picked up one that's not +// yet taken. +// +// TODO(cjpatton) Only ECDSA is supported with delegated credentials for now; +// we'd like to suppoort for EcDSA signatures once these have better support +// upstream. + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/x509" + "encoding/asn1" + "encoding/binary" + "errors" + "fmt" + "time" +) + +const ( + // length of the public key field + dcPubKeyFieldLen = 3 + dcMaxTTLSeconds = 60 * 60 * 24 * 7 // 7 days + dcMaxTTL = time.Duration(dcMaxTTLSeconds * time.Second) + dcMaxPublicKeyLen = 1 << 24 // Bytes + dcMaxSignatureLen = 1 << 16 // Bytes +) + +var errNoDelegationUsage = errors.New("certificate not authorized for delegation") + +// delegationUsageId is the DelegationUsage X.509 extension OID +// +// NOTE(cjpatton) This OID is a child of Cloudflare's IANA-assigned OID. +var delegationUsageId = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44} + +// canDelegate returns true if a certificate can be used for delegated +// credentials. +func canDelegate(cert *x509.Certificate) bool { + // Check that the digitalSignature key usage is set. + if (cert.KeyUsage & x509.KeyUsageDigitalSignature) == 0 { + return false + } + + // Check that the certificate has the DelegationUsage extension and that + // it's non-critical (per the spec). + for _, extension := range cert.Extensions { + if extension.Id.Equal(delegationUsageId) { + return true + } + } + return false +} + +// credential stores the public components of a credential. +type credential struct { + // The serialized form of the credential. + raw []byte + + // The amount of time for which the credential is valid. Specifically, the + // the credential expires `ValidTime` seconds after the `notBefore` of the + // delegation certificate. The delegator shall not issue delegated + // credentials that are valid for more than 7 days from the current time. + // + // When this data structure is serialized, this value is converted to a + // uint32 representing the duration in seconds. + validTime time.Duration + + // The signature scheme associated with the delegated credential public key. + expectedCertVerifyAlgorithm SignatureScheme + + // The version of TLS in which the credential will be used. + expectedVersion uint16 + + // The credential public key. + publicKey crypto.PublicKey +} + +// isExpired returns true if the credential has expired. The end of the validity +// interval is defined as the delegator certificate's notBefore field (`start`) +// plus ValidTime seconds. This function simply checks that the current time +// (`now`) is before the end of the valdity interval. +func (cred *credential) isExpired(start, now time.Time) bool { + end := start.Add(cred.validTime) + return !now.Before(end) +} + +// invalidTTL returns true if the credential's validity period is longer than the +// maximum permitted. This is defined by the certificate's notBefore field +// (`start`) plus the ValidTime, minus the current time (`now`). +func (cred *credential) invalidTTL(start, now time.Time) bool { + return cred.validTime > (now.Sub(start) + dcMaxTTL).Round(time.Second) +} + +// marshalSubjectPublicKeyInfo returns a DER encoded SubjectPublicKeyInfo structure +// (as defined in the X.509 standard) for the credential. +func (cred *credential) marshalSubjectPublicKeyInfo() ([]byte, error) { + switch cred.expectedCertVerifyAlgorithm { + case ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512: + serializedPublicKey, err := x509.MarshalPKIXPublicKey(cred.publicKey) + if err != nil { + return nil, err + } + return serializedPublicKey, nil + + default: + return nil, fmt.Errorf("unsupported signature scheme: 0x%04x", cred.expectedCertVerifyAlgorithm) + } +} + +// marshal encodes a credential in the wire format specified in +// https://tools.ietf.org/html/draft-ietf-tls-subcerts-02. +func (cred *credential) marshal() ([]byte, error) { + // The number of bytes comprising the DC parameters, which includes the + // validity time (4 bytes), the signature scheme of the public key (2 bytes), and + // the protocol version (2 bytes). + paramsLen := 8 + + // The first 4 bytes are the valid_time, scheme, and version fields. + serialized := make([]byte, paramsLen+dcPubKeyFieldLen) + binary.BigEndian.PutUint32(serialized, uint32(cred.validTime/time.Second)) + binary.BigEndian.PutUint16(serialized[4:], uint16(cred.expectedCertVerifyAlgorithm)) + binary.BigEndian.PutUint16(serialized[6:], cred.expectedVersion) + + // Encode the public key and assert that the encoding is no longer than 2^16 + // bytes (per the spec). + serializedPublicKey, err := cred.marshalSubjectPublicKeyInfo() + if err != nil { + return nil, err + } + if len(serializedPublicKey) > dcMaxPublicKeyLen { + return nil, errors.New("public key is too long") + } + + // The next 3 bytes are the length of the public key field, which may be up + // to 2^24 bytes long. + putUint24(serialized[paramsLen:], len(serializedPublicKey)) + + // The remaining bytes are the public key itself. + serialized = append(serialized, serializedPublicKey...) + cred.raw = serialized + return serialized, nil +} + +// unmarshalCredential decodes a credential and returns it. +func unmarshalCredential(serialized []byte) (*credential, error) { + // The number of bytes comprising the DC parameters. + paramsLen := 8 + + if len(serialized) < paramsLen+dcPubKeyFieldLen { + return nil, errors.New("credential is too short") + } + + // Parse the valid_time, scheme, and version fields. + validTime := time.Duration(binary.BigEndian.Uint32(serialized)) * time.Second + scheme := SignatureScheme(binary.BigEndian.Uint16(serialized[4:])) + version := binary.BigEndian.Uint16(serialized[6:]) + + // Parse the SubjectPublicKeyInfo. + pk, err := x509.ParsePKIXPublicKey(serialized[paramsLen+dcPubKeyFieldLen:]) + if err != nil { + return nil, err + } + + if _, ok := pk.(*ecdsa.PublicKey); !ok { + return nil, fmt.Errorf("unsupported delegation key type: %T", pk) + } + + return &credential{ + raw: serialized, + validTime: validTime, + expectedCertVerifyAlgorithm: scheme, + expectedVersion: version, + publicKey: pk, + }, nil +} + +// getCredentialLen returns the number of bytes comprising the serialized +// credential that starts at the beginning of the input slice. It returns an +// error if the input is too short to contain a credential. +func getCredentialLen(serialized []byte) (int, error) { + paramsLen := 8 + if len(serialized) < paramsLen+dcPubKeyFieldLen { + return 0, errors.New("credential is too short") + } + // First several bytes are the valid_time, scheme, and version fields. + serialized = serialized[paramsLen:] + + // The next 3 bytes are the length of the serialized public key, which may + // be up to 2^24 bytes in length. + serializedPublicKeyLen := getUint24(serialized) + serialized = serialized[dcPubKeyFieldLen:] + + if len(serialized) < serializedPublicKeyLen { + return 0, errors.New("public key of credential is too short") + } + + return paramsLen + dcPubKeyFieldLen + serializedPublicKeyLen, nil +} + +// delegatedCredential stores a credential and its delegation. +type delegatedCredential struct { + raw []byte + + // The credential, which contains a public and its validity time. + cred *credential + + // The signature scheme used to sign the credential. + algorithm SignatureScheme + + // The credential's delegation. + signature []byte +} + +// ensureCertificateHasLeaf parses the leaf certificate if needed. +func ensureCertificateHasLeaf(cert *Certificate) error { + var err error + if cert.Leaf == nil { + if len(cert.Certificate[0]) == 0 { + return errors.New("missing leaf certificate") + } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return err + } + } + return nil +} + +// validate checks that that the signature is valid, that the credential hasn't +// expired, and that the TTL is valid. It also checks that certificate can be +// used for delegation. +func (dc *delegatedCredential) validate(cert *x509.Certificate, now time.Time) (bool, error) { + // Check that the cert can delegate. + if !canDelegate(cert) { + return false, errNoDelegationUsage + } + + if dc.cred.isExpired(cert.NotBefore, now) { + return false, errors.New("credential has expired") + } + + if dc.cred.invalidTTL(cert.NotBefore, now) { + return false, errors.New("credential TTL is invalid") + } + + // Prepare the credential for verification. + rawCred, err := dc.cred.marshal() + if err != nil { + return false, err + } + hash := getHash(dc.algorithm) + in := prepareDelegation(hash, rawCred, cert.Raw, dc.algorithm) + + // TODO(any) This code overlaps significantly with verifyHandshakeSignature() + // in ../auth.go. This should be refactored. + switch dc.algorithm { + case ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512: + pk, ok := cert.PublicKey.(*ecdsa.PublicKey) + if !ok { + return false, errors.New("expected ECDSA public key") + } + sig := new(ecdsaSignature) + if _, err = asn1.Unmarshal(dc.signature, sig); err != nil { + return false, err + } + return ecdsa.Verify(pk, in, sig.R, sig.S), nil + + default: + return false, fmt.Errorf( + "unsupported signature scheme: 0x%04x", dc.algorithm) + } +} + +// unmarshalDelegatedCredential decodes a DelegatedCredential structure. +func unmarshalDelegatedCredential(serialized []byte) (*delegatedCredential, error) { + // Get the length of the serialized credential that begins at the start of + // the input slice. + serializedCredentialLen, err := getCredentialLen(serialized) + if err != nil { + return nil, err + } + + // Parse the credential. + cred, err := unmarshalCredential(serialized[:serializedCredentialLen]) + if err != nil { + return nil, err + } + + // Parse the signature scheme. + serialized = serialized[serializedCredentialLen:] + if len(serialized) < 4 { + return nil, errors.New("delegated credential is too short") + } + scheme := SignatureScheme(binary.BigEndian.Uint16(serialized)) + + // Parse the signature length. + serialized = serialized[2:] + serializedSignatureLen := binary.BigEndian.Uint16(serialized) + + // Prase the signature. + serialized = serialized[2:] + if len(serialized) < int(serializedSignatureLen) { + return nil, errors.New("signature of delegated credential is too short") + } + sig := serialized[:serializedSignatureLen] + + return &delegatedCredential{ + raw: serialized, + cred: cred, + algorithm: scheme, + signature: sig, + }, nil +} + +// getCurve maps the SignatureScheme to its corresponding elliptic.Curve. +func getCurve(scheme SignatureScheme) elliptic.Curve { + switch scheme { + case ECDSAWithP256AndSHA256: + return elliptic.P256() + case ECDSAWithP384AndSHA384: + return elliptic.P384() + case ECDSAWithP521AndSHA512: + return elliptic.P521() + default: + return nil + } +} + +// getHash maps the SignatureScheme to its corresponding hash function. +// +// TODO(any) This function overlaps with hashForSignatureScheme in 13.go. +func getHash(scheme SignatureScheme) crypto.Hash { + switch scheme { + case ECDSAWithP256AndSHA256: + return crypto.SHA256 + case ECDSAWithP384AndSHA384: + return crypto.SHA384 + case ECDSAWithP521AndSHA512: + return crypto.SHA512 + default: + return 0 // Unknown hash function + } +} + +// prepareDelegation returns a hash of the message that the delegator is to +// sign. The inputs are the credential (`cred`), the DER-encoded delegator +// certificate (`delegatorCert`) and the signature scheme of the delegator +// (`delegatorAlgorithm`). +func prepareDelegation(hash crypto.Hash, cred, delegatorCert []byte, delegatorAlgorithm SignatureScheme) []byte { + h := hash.New() + + // The header. + h.Write(bytes.Repeat([]byte{0x20}, 64)) + h.Write([]byte("TLS, server delegated credentials")) + h.Write([]byte{0x00}) + + // The delegation certificate. + h.Write(delegatorCert) + + // The credential. + h.Write(cred) + + // The delegator signature scheme. + var serializedScheme [2]byte + binary.BigEndian.PutUint16(serializedScheme[:], uint16(delegatorAlgorithm)) + h.Write(serializedScheme[:]) + + return h.Sum(nil) +} diff --git a/vendor/github.com/marten-seemann/qtls/ticket.go b/vendor/github.com/marten-seemann/qtls/ticket.go new file mode 100644 index 00000000..e5bffa99 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/ticket.go @@ -0,0 +1,326 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "errors" + "io" +) + +// A SessionTicketSealer provides a way to securely encapsulate +// session state for storage on the client. All methods are safe for +// concurrent use. +type SessionTicketSealer interface { + // Seal returns a session ticket value that can be later passed to Unseal + // to recover the content, usually by encrypting it. The ticket will be sent + // to the client to be stored, and will be sent back in plaintext, so it can + // be read and modified by an attacker. + Seal(cs *ConnectionState, content []byte) (ticket []byte, err error) + + // Unseal returns a session ticket contents. The ticket can't be safely + // assumed to have been generated by Seal. + // If unable to unseal the ticket, the connection will proceed with a + // complete handshake. + Unseal(chi *ClientHelloInfo, ticket []byte) (content []byte, success bool) +} + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + usedEMS bool + masterSecret []byte + certificates [][]byte + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + usedOldKey bool +} + +func (s *sessionState) equal(i interface{}) bool { + s1, ok := i.(*sessionState) + if !ok { + return false + } + + if s.vers != s1.vers || + s.usedEMS != s1.usedEMS || + s.cipherSuite != s1.cipherSuite || + !bytes.Equal(s.masterSecret, s1.masterSecret) { + return false + } + + if len(s.certificates) != len(s1.certificates) { + return false + } + + for i := range s.certificates { + if !bytes.Equal(s.certificates[i], s1.certificates[i]) { + return false + } + } + + return true +} + +func (s *sessionState) marshal() []byte { + length := 2 + 2 + 2 + len(s.masterSecret) + 2 + for _, cert := range s.certificates { + length += 4 + len(cert) + } + + ret := make([]byte, length) + x := ret + was_used := byte(0) + if s.usedEMS { + was_used = byte(0x80) + } + + x[0] = byte(s.vers>>8) | byte(was_used) + x[1] = byte(s.vers) + x[2] = byte(s.cipherSuite >> 8) + x[3] = byte(s.cipherSuite) + x[4] = byte(len(s.masterSecret) >> 8) + x[5] = byte(len(s.masterSecret)) + x = x[6:] + copy(x, s.masterSecret) + x = x[len(s.masterSecret):] + + x[0] = byte(len(s.certificates) >> 8) + x[1] = byte(len(s.certificates)) + x = x[2:] + + for _, cert := range s.certificates { + x[0] = byte(len(cert) >> 24) + x[1] = byte(len(cert) >> 16) + x[2] = byte(len(cert) >> 8) + x[3] = byte(len(cert)) + copy(x[4:], cert) + x = x[4+len(cert):] + } + + return ret +} + +func (s *sessionState) unmarshal(data []byte) alert { + if len(data) < 8 { + return alertDecodeError + } + + s.vers = (uint16(data[0])<<8 | uint16(data[1])) & 0x7fff + s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) + s.usedEMS = (data[0] & 0x80) == 0x80 + masterSecretLen := int(data[4])<<8 | int(data[5]) + data = data[6:] + if len(data) < masterSecretLen { + return alertDecodeError + } + + s.masterSecret = data[:masterSecretLen] + data = data[masterSecretLen:] + + if len(data) < 2 { + return alertDecodeError + } + + numCerts := int(data[0])<<8 | int(data[1]) + data = data[2:] + + s.certificates = make([][]byte, numCerts) + for i := range s.certificates { + if len(data) < 4 { + return alertDecodeError + } + certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + data = data[4:] + if certLen < 0 { + return alertDecodeError + } + if len(data) < certLen { + return alertDecodeError + } + s.certificates[i] = data[:certLen] + data = data[certLen:] + } + + if len(data) != 0 { + return alertDecodeError + } + return alertSuccess +} + +type sessionState13 struct { + vers uint16 + suite uint16 + ageAdd uint32 + createdAt uint64 + maxEarlyDataLen uint32 + pskSecret []byte + alpnProtocol string + SNI string +} + +func (s *sessionState13) equal(i interface{}) bool { + s1, ok := i.(*sessionState13) + if !ok { + return false + } + + return s.vers == s1.vers && + s.suite == s1.suite && + s.ageAdd == s1.ageAdd && + s.createdAt == s1.createdAt && + s.maxEarlyDataLen == s1.maxEarlyDataLen && + subtle.ConstantTimeCompare(s.pskSecret, s1.pskSecret) == 1 && + s.alpnProtocol == s1.alpnProtocol && + s.SNI == s1.SNI +} + +func (s *sessionState13) marshal() []byte { + length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.pskSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) + + x := make([]byte, length) + x[0] = byte(s.vers >> 8) + x[1] = byte(s.vers) + x[2] = byte(s.suite >> 8) + x[3] = byte(s.suite) + x[4] = byte(s.ageAdd >> 24) + x[5] = byte(s.ageAdd >> 16) + x[6] = byte(s.ageAdd >> 8) + x[7] = byte(s.ageAdd) + x[8] = byte(s.createdAt >> 56) + x[9] = byte(s.createdAt >> 48) + x[10] = byte(s.createdAt >> 40) + x[11] = byte(s.createdAt >> 32) + x[12] = byte(s.createdAt >> 24) + x[13] = byte(s.createdAt >> 16) + x[14] = byte(s.createdAt >> 8) + x[15] = byte(s.createdAt) + x[16] = byte(s.maxEarlyDataLen >> 24) + x[17] = byte(s.maxEarlyDataLen >> 16) + x[18] = byte(s.maxEarlyDataLen >> 8) + x[19] = byte(s.maxEarlyDataLen) + x[20] = byte(len(s.pskSecret) >> 8) + x[21] = byte(len(s.pskSecret)) + copy(x[22:], s.pskSecret) + z := x[22+len(s.pskSecret):] + z[0] = byte(len(s.alpnProtocol) >> 8) + z[1] = byte(len(s.alpnProtocol)) + copy(z[2:], s.alpnProtocol) + z = z[2+len(s.alpnProtocol):] + z[0] = byte(len(s.SNI) >> 8) + z[1] = byte(len(s.SNI)) + copy(z[2:], s.SNI) + + return x +} + +func (s *sessionState13) unmarshal(data []byte) alert { + if len(data) < 24 { + return alertDecodeError + } + + s.vers = uint16(data[0])<<8 | uint16(data[1]) + s.suite = uint16(data[2])<<8 | uint16(data[3]) + s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) + s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 | + uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15]) + s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19]) + + l := int(data[20])<<8 | int(data[21]) + if len(data) < 22+l+2 { + return alertDecodeError + } + s.pskSecret = data[22 : 22+l] + z := data[22+l:] + + l = int(z[0])<<8 | int(z[1]) + if len(z) < 2+l+2 { + return alertDecodeError + } + s.alpnProtocol = string(z[2 : 2+l]) + z = z[2+l:] + + l = int(z[0])<<8 | int(z[1]) + if len(z) != 2+l { + return alertDecodeError + } + s.SNI = string(z[2 : 2+l]) + + return alertSuccess +} + +func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) { + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + key := c.config.ticketKeys()[0] + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (serialized []byte, usedOldKey bool) { + if c.config.SessionTicketsDisabled || + len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { + return nil, false + } + + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + keys := c.config.ticketKeys() + keyIndex := -1 + for i, candidateKey := range keys { + if bytes.Equal(keyName, candidateKey.keyName[:]) { + keyIndex = i + break + } + } + + if keyIndex == -1 { + return nil, false + } + key := &keys[keyIndex] + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, false + } + ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] + plaintext := ciphertext + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + return plaintext, keyIndex > 0 +} diff --git a/vendor/github.com/marten-seemann/qtls/tls.go b/vendor/github.com/marten-seemann/qtls/tls.go new file mode 100644 index 00000000..0dd6484d --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls/tls.go @@ -0,0 +1,297 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tls partially implements TLS 1.2, as specified in RFC 5246. +package qtls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "net" + "strings" + "time" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config) *Conn { + return &Conn{conn: conn, config: config} +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config) *Conn { + return &Conn{conn: conn, config: config, isClient: true} +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type listener struct { + net.Listener + config *Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection is of type *Conn. +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return Server(c, l.config), nil +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config) net.Listener { + l := new(listener) + l.Listener = inner + l.config = config + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config) (net.Listener, error) { + if config == nil || (len(config.Certificates) == 0 && config.GetCertificate == nil) { + return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config), nil +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := time.Until(dialer.Deadline) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- timeoutError{} + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = defaultConfig() + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + } + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} + +// LoadX509KeyPair reads and parses a public/private key pair from a pair +// of files. The files must contain PEM encoded data. The certificate file +// may contain intermediate certificates following the leaf certificate to +// form a certificate chain. On successful return, Certificate.Leaf will +// be nil because the parsed form of the certificate is not retained. +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := ioutil.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := ioutil.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. On successful return, Certificate.Leaf will be nil because +// the parsed form of the certificate is not retained. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + fail := func(err error) (Certificate, error) { return Certificate{}, err } + + var cert Certificate + var skippedBlockTypes []string + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(cert.Certificate) == 0 { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in certificate input")) + } + if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { + return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) + } + return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + + skippedBlockTypes = skippedBlockTypes[:0] + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in key input")) + } + if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { + return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) + } + return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) + } + + var err error + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return fail(err) + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fail(err) + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.N.Cmp(priv.N) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + default: + return fail(errors.New("tls: unknown public key algorithm")) + } + + return cert, nil +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey: + return key, nil + default: + return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse private key") +} diff --git a/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305.go b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305.go new file mode 100644 index 00000000..bbb86efe --- /dev/null +++ b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305.go @@ -0,0 +1,101 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539, +// and its extended nonce variant XChaCha20-Poly1305. +package chacha20poly1305 // import "golang.org/x/crypto/chacha20poly1305" + +import ( + "crypto/cipher" + "encoding/binary" + "errors" +) + +const ( + // KeySize is the size of the key used by this AEAD, in bytes. + KeySize = 32 + + // NonceSize is the size of the nonce used with the standard variant of this + // AEAD, in bytes. + // + // Note that this is too short to be safely generated at random if the same + // key is reused more than 2³² times. + NonceSize = 12 + + // NonceSizeX is the size of the nonce used with the XChaCha20-Poly1305 + // variant of this AEAD, in bytes. + NonceSizeX = 24 +) + +type chacha20poly1305 struct { + key [8]uint32 +} + +// New returns a ChaCha20-Poly1305 AEAD that uses the given 256-bit key. +func New(key []byte) (cipher.AEAD, error) { + if len(key) != KeySize { + return nil, errors.New("chacha20poly1305: bad key length") + } + ret := new(chacha20poly1305) + ret.key[0] = binary.LittleEndian.Uint32(key[0:4]) + ret.key[1] = binary.LittleEndian.Uint32(key[4:8]) + ret.key[2] = binary.LittleEndian.Uint32(key[8:12]) + ret.key[3] = binary.LittleEndian.Uint32(key[12:16]) + ret.key[4] = binary.LittleEndian.Uint32(key[16:20]) + ret.key[5] = binary.LittleEndian.Uint32(key[20:24]) + ret.key[6] = binary.LittleEndian.Uint32(key[24:28]) + ret.key[7] = binary.LittleEndian.Uint32(key[28:32]) + return ret, nil +} + +func (c *chacha20poly1305) NonceSize() int { + return NonceSize +} + +func (c *chacha20poly1305) Overhead() int { + return 16 +} + +func (c *chacha20poly1305) Seal(dst, nonce, plaintext, additionalData []byte) []byte { + if len(nonce) != NonceSize { + panic("chacha20poly1305: bad nonce length passed to Seal") + } + + if uint64(len(plaintext)) > (1<<38)-64 { + panic("chacha20poly1305: plaintext too large") + } + + return c.seal(dst, nonce, plaintext, additionalData) +} + +var errOpen = errors.New("chacha20poly1305: message authentication failed") + +func (c *chacha20poly1305) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + if len(nonce) != NonceSize { + panic("chacha20poly1305: bad nonce length passed to Open") + } + if len(ciphertext) < 16 { + return nil, errOpen + } + if uint64(len(ciphertext)) > (1<<38)-48 { + panic("chacha20poly1305: ciphertext too large") + } + + return c.open(dst, nonce, ciphertext, additionalData) +} + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} diff --git a/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.go b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.go new file mode 100644 index 00000000..2aa4fd89 --- /dev/null +++ b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.go @@ -0,0 +1,86 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7,amd64,!gccgo,!appengine + +package chacha20poly1305 + +import ( + "encoding/binary" + + "golang.org/x/crypto/internal/subtle" + "golang.org/x/sys/cpu" +) + +//go:noescape +func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool + +//go:noescape +func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte) + +var ( + useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2 +) + +// setupState writes a ChaCha20 input matrix to state. See +// https://tools.ietf.org/html/rfc7539#section-2.3. +func setupState(state *[16]uint32, key *[8]uint32, nonce []byte) { + state[0] = 0x61707865 + state[1] = 0x3320646e + state[2] = 0x79622d32 + state[3] = 0x6b206574 + + state[4] = key[0] + state[5] = key[1] + state[6] = key[2] + state[7] = key[3] + state[8] = key[4] + state[9] = key[5] + state[10] = key[6] + state[11] = key[7] + + state[12] = 0 + state[13] = binary.LittleEndian.Uint32(nonce[:4]) + state[14] = binary.LittleEndian.Uint32(nonce[4:8]) + state[15] = binary.LittleEndian.Uint32(nonce[8:12]) +} + +func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte { + if !cpu.X86.HasSSSE3 { + return c.sealGeneric(dst, nonce, plaintext, additionalData) + } + + var state [16]uint32 + setupState(&state, &c.key, nonce) + + ret, out := sliceForAppend(dst, len(plaintext)+16) + if subtle.InexactOverlap(out, plaintext) { + panic("chacha20poly1305: invalid buffer overlap") + } + chacha20Poly1305Seal(out[:], state[:], plaintext, additionalData) + return ret +} + +func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + if !cpu.X86.HasSSSE3 { + return c.openGeneric(dst, nonce, ciphertext, additionalData) + } + + var state [16]uint32 + setupState(&state, &c.key, nonce) + + ciphertext = ciphertext[:len(ciphertext)-16] + ret, out := sliceForAppend(dst, len(ciphertext)) + if subtle.InexactOverlap(out, ciphertext) { + panic("chacha20poly1305: invalid buffer overlap") + } + if !chacha20Poly1305Open(out, state[:], ciphertext, additionalData) { + for i := range out { + out[i] = 0 + } + return nil, errOpen + } + + return ret, nil +} diff --git a/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.s b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.s new file mode 100644 index 00000000..af76bbcc --- /dev/null +++ b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_amd64.s @@ -0,0 +1,2695 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file was originally from https://golang.org/cl/24717 by Vlad Krasnov of CloudFlare. + +// +build go1.7,amd64,!gccgo,!appengine + +#include "textflag.h" +// General register allocation +#define oup DI +#define inp SI +#define inl BX +#define adp CX // free to reuse, after we hash the additional data +#define keyp R8 // free to reuse, when we copy the key to stack +#define itr2 R9 // general iterator +#define itr1 CX // general iterator +#define acc0 R10 +#define acc1 R11 +#define acc2 R12 +#define t0 R13 +#define t1 R14 +#define t2 R15 +#define t3 R8 +// Register and stack allocation for the SSE code +#define rStore (0*16)(BP) +#define sStore (1*16)(BP) +#define state1Store (2*16)(BP) +#define state2Store (3*16)(BP) +#define tmpStore (4*16)(BP) +#define ctr0Store (5*16)(BP) +#define ctr1Store (6*16)(BP) +#define ctr2Store (7*16)(BP) +#define ctr3Store (8*16)(BP) +#define A0 X0 +#define A1 X1 +#define A2 X2 +#define B0 X3 +#define B1 X4 +#define B2 X5 +#define C0 X6 +#define C1 X7 +#define C2 X8 +#define D0 X9 +#define D1 X10 +#define D2 X11 +#define T0 X12 +#define T1 X13 +#define T2 X14 +#define T3 X15 +#define A3 T0 +#define B3 T1 +#define C3 T2 +#define D3 T3 +// Register and stack allocation for the AVX2 code +#define rsStoreAVX2 (0*32)(BP) +#define state1StoreAVX2 (1*32)(BP) +#define state2StoreAVX2 (2*32)(BP) +#define ctr0StoreAVX2 (3*32)(BP) +#define ctr1StoreAVX2 (4*32)(BP) +#define ctr2StoreAVX2 (5*32)(BP) +#define ctr3StoreAVX2 (6*32)(BP) +#define tmpStoreAVX2 (7*32)(BP) // 256 bytes on stack +#define AA0 Y0 +#define AA1 Y5 +#define AA2 Y6 +#define AA3 Y7 +#define BB0 Y14 +#define BB1 Y9 +#define BB2 Y10 +#define BB3 Y11 +#define CC0 Y12 +#define CC1 Y13 +#define CC2 Y8 +#define CC3 Y15 +#define DD0 Y4 +#define DD1 Y1 +#define DD2 Y2 +#define DD3 Y3 +#define TT0 DD3 +#define TT1 AA3 +#define TT2 BB3 +#define TT3 CC3 +// ChaCha20 constants +DATA ·chacha20Constants<>+0x00(SB)/4, $0x61707865 +DATA ·chacha20Constants<>+0x04(SB)/4, $0x3320646e +DATA ·chacha20Constants<>+0x08(SB)/4, $0x79622d32 +DATA ·chacha20Constants<>+0x0c(SB)/4, $0x6b206574 +DATA ·chacha20Constants<>+0x10(SB)/4, $0x61707865 +DATA ·chacha20Constants<>+0x14(SB)/4, $0x3320646e +DATA ·chacha20Constants<>+0x18(SB)/4, $0x79622d32 +DATA ·chacha20Constants<>+0x1c(SB)/4, $0x6b206574 +// <<< 16 with PSHUFB +DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302 +DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A +DATA ·rol16<>+0x10(SB)/8, $0x0504070601000302 +DATA ·rol16<>+0x18(SB)/8, $0x0D0C0F0E09080B0A +// <<< 8 with PSHUFB +DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003 +DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B +DATA ·rol8<>+0x10(SB)/8, $0x0605040702010003 +DATA ·rol8<>+0x18(SB)/8, $0x0E0D0C0F0A09080B + +DATA ·avx2InitMask<>+0x00(SB)/8, $0x0 +DATA ·avx2InitMask<>+0x08(SB)/8, $0x0 +DATA ·avx2InitMask<>+0x10(SB)/8, $0x1 +DATA ·avx2InitMask<>+0x18(SB)/8, $0x0 + +DATA ·avx2IncMask<>+0x00(SB)/8, $0x2 +DATA ·avx2IncMask<>+0x08(SB)/8, $0x0 +DATA ·avx2IncMask<>+0x10(SB)/8, $0x2 +DATA ·avx2IncMask<>+0x18(SB)/8, $0x0 +// Poly1305 key clamp +DATA ·polyClampMask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF +DATA ·polyClampMask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC +DATA ·polyClampMask<>+0x10(SB)/8, $0xFFFFFFFFFFFFFFFF +DATA ·polyClampMask<>+0x18(SB)/8, $0xFFFFFFFFFFFFFFFF + +DATA ·sseIncMask<>+0x00(SB)/8, $0x1 +DATA ·sseIncMask<>+0x08(SB)/8, $0x0 +// To load/store the last < 16 bytes in a buffer +DATA ·andMask<>+0x00(SB)/8, $0x00000000000000ff +DATA ·andMask<>+0x08(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x10(SB)/8, $0x000000000000ffff +DATA ·andMask<>+0x18(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x20(SB)/8, $0x0000000000ffffff +DATA ·andMask<>+0x28(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x30(SB)/8, $0x00000000ffffffff +DATA ·andMask<>+0x38(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x40(SB)/8, $0x000000ffffffffff +DATA ·andMask<>+0x48(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x50(SB)/8, $0x0000ffffffffffff +DATA ·andMask<>+0x58(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x60(SB)/8, $0x00ffffffffffffff +DATA ·andMask<>+0x68(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x70(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0x78(SB)/8, $0x0000000000000000 +DATA ·andMask<>+0x80(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0x88(SB)/8, $0x00000000000000ff +DATA ·andMask<>+0x90(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0x98(SB)/8, $0x000000000000ffff +DATA ·andMask<>+0xa0(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0xa8(SB)/8, $0x0000000000ffffff +DATA ·andMask<>+0xb0(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0xb8(SB)/8, $0x00000000ffffffff +DATA ·andMask<>+0xc0(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0xc8(SB)/8, $0x000000ffffffffff +DATA ·andMask<>+0xd0(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0xd8(SB)/8, $0x0000ffffffffffff +DATA ·andMask<>+0xe0(SB)/8, $0xffffffffffffffff +DATA ·andMask<>+0xe8(SB)/8, $0x00ffffffffffffff + +GLOBL ·chacha20Constants<>(SB), (NOPTR+RODATA), $32 +GLOBL ·rol16<>(SB), (NOPTR+RODATA), $32 +GLOBL ·rol8<>(SB), (NOPTR+RODATA), $32 +GLOBL ·sseIncMask<>(SB), (NOPTR+RODATA), $16 +GLOBL ·avx2IncMask<>(SB), (NOPTR+RODATA), $32 +GLOBL ·avx2InitMask<>(SB), (NOPTR+RODATA), $32 +GLOBL ·polyClampMask<>(SB), (NOPTR+RODATA), $32 +GLOBL ·andMask<>(SB), (NOPTR+RODATA), $240 +// No PALIGNR in Go ASM yet (but VPALIGNR is present). +#define shiftB0Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x04 // PALIGNR $4, X3, X3 +#define shiftB1Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xe4; BYTE $0x04 // PALIGNR $4, X4, X4 +#define shiftB2Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x04 // PALIGNR $4, X5, X5 +#define shiftB3Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x04 // PALIGNR $4, X13, X13 +#define shiftC0Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xf6; BYTE $0x08 // PALIGNR $8, X6, X6 +#define shiftC1Left BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xff; BYTE $0x08 // PALIGNR $8, X7, X7 +#define shiftC2Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xc0; BYTE $0x08 // PALIGNR $8, X8, X8 +#define shiftC3Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xf6; BYTE $0x08 // PALIGNR $8, X14, X14 +#define shiftD0Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xc9; BYTE $0x0c // PALIGNR $12, X9, X9 +#define shiftD1Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xd2; BYTE $0x0c // PALIGNR $12, X10, X10 +#define shiftD2Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x0c // PALIGNR $12, X11, X11 +#define shiftD3Left BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xff; BYTE $0x0c // PALIGNR $12, X15, X15 +#define shiftB0Right BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x0c // PALIGNR $12, X3, X3 +#define shiftB1Right BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xe4; BYTE $0x0c // PALIGNR $12, X4, X4 +#define shiftB2Right BYTE $0x66; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x0c // PALIGNR $12, X5, X5 +#define shiftB3Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xed; BYTE $0x0c // PALIGNR $12, X13, X13 +#define shiftC0Right shiftC0Left +#define shiftC1Right shiftC1Left +#define shiftC2Right shiftC2Left +#define shiftC3Right shiftC3Left +#define shiftD0Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xc9; BYTE $0x04 // PALIGNR $4, X9, X9 +#define shiftD1Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xd2; BYTE $0x04 // PALIGNR $4, X10, X10 +#define shiftD2Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xdb; BYTE $0x04 // PALIGNR $4, X11, X11 +#define shiftD3Right BYTE $0x66; BYTE $0x45; BYTE $0x0f; BYTE $0x3a; BYTE $0x0f; BYTE $0xff; BYTE $0x04 // PALIGNR $4, X15, X15 +// Some macros +#define chachaQR(A, B, C, D, T) \ + PADDD B, A; PXOR A, D; PSHUFB ·rol16<>(SB), D \ + PADDD D, C; PXOR C, B; MOVO B, T; PSLLL $12, T; PSRLL $20, B; PXOR T, B \ + PADDD B, A; PXOR A, D; PSHUFB ·rol8<>(SB), D \ + PADDD D, C; PXOR C, B; MOVO B, T; PSLLL $7, T; PSRLL $25, B; PXOR T, B + +#define chachaQR_AVX2(A, B, C, D, T) \ + VPADDD B, A, A; VPXOR A, D, D; VPSHUFB ·rol16<>(SB), D, D \ + VPADDD D, C, C; VPXOR C, B, B; VPSLLD $12, B, T; VPSRLD $20, B, B; VPXOR T, B, B \ + VPADDD B, A, A; VPXOR A, D, D; VPSHUFB ·rol8<>(SB), D, D \ + VPADDD D, C, C; VPXOR C, B, B; VPSLLD $7, B, T; VPSRLD $25, B, B; VPXOR T, B, B + +#define polyAdd(S) ADDQ S, acc0; ADCQ 8+S, acc1; ADCQ $1, acc2 +#define polyMulStage1 MOVQ (0*8)(BP), AX; MOVQ AX, t2; MULQ acc0; MOVQ AX, t0; MOVQ DX, t1; MOVQ (0*8)(BP), AX; MULQ acc1; IMULQ acc2, t2; ADDQ AX, t1; ADCQ DX, t2 +#define polyMulStage2 MOVQ (1*8)(BP), AX; MOVQ AX, t3; MULQ acc0; ADDQ AX, t1; ADCQ $0, DX; MOVQ DX, acc0; MOVQ (1*8)(BP), AX; MULQ acc1; ADDQ AX, t2; ADCQ $0, DX +#define polyMulStage3 IMULQ acc2, t3; ADDQ acc0, t2; ADCQ DX, t3 +#define polyMulReduceStage MOVQ t0, acc0; MOVQ t1, acc1; MOVQ t2, acc2; ANDQ $3, acc2; MOVQ t2, t0; ANDQ $-4, t0; MOVQ t3, t1; SHRQ $2, t2:t3; SHRQ $2, t3; ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $0, acc2; ADDQ t2, acc0; ADCQ t3, acc1; ADCQ $0, acc2 + +#define polyMulStage1_AVX2 MOVQ (0*8)(BP), DX; MOVQ DX, t2; MULXQ acc0, t0, t1; IMULQ acc2, t2; MULXQ acc1, AX, DX; ADDQ AX, t1; ADCQ DX, t2 +#define polyMulStage2_AVX2 MOVQ (1*8)(BP), DX; MULXQ acc0, acc0, AX; ADDQ acc0, t1; MULXQ acc1, acc1, t3; ADCQ acc1, t2; ADCQ $0, t3 +#define polyMulStage3_AVX2 IMULQ acc2, DX; ADDQ AX, t2; ADCQ DX, t3 + +#define polyMul polyMulStage1; polyMulStage2; polyMulStage3; polyMulReduceStage +#define polyMulAVX2 polyMulStage1_AVX2; polyMulStage2_AVX2; polyMulStage3_AVX2; polyMulReduceStage +// ---------------------------------------------------------------------------- +TEXT polyHashADInternal<>(SB), NOSPLIT, $0 + // adp points to beginning of additional data + // itr2 holds ad length + XORQ acc0, acc0 + XORQ acc1, acc1 + XORQ acc2, acc2 + CMPQ itr2, $13 + JNE hashADLoop + +openFastTLSAD: + // Special treatment for the TLS case of 13 bytes + MOVQ (adp), acc0 + MOVQ 5(adp), acc1 + SHRQ $24, acc1 + MOVQ $1, acc2 + polyMul + RET + +hashADLoop: + // Hash in 16 byte chunks + CMPQ itr2, $16 + JB hashADTail + polyAdd(0(adp)) + LEAQ (1*16)(adp), adp + SUBQ $16, itr2 + polyMul + JMP hashADLoop + +hashADTail: + CMPQ itr2, $0 + JE hashADDone + + // Hash last < 16 byte tail + XORQ t0, t0 + XORQ t1, t1 + XORQ t2, t2 + ADDQ itr2, adp + +hashADTailLoop: + SHLQ $8, t1:t0 + SHLQ $8, t0 + MOVB -1(adp), t2 + XORQ t2, t0 + DECQ adp + DECQ itr2 + JNE hashADTailLoop + +hashADTailFinish: + ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2 + polyMul + + // Finished AD +hashADDone: + RET + +// ---------------------------------------------------------------------------- +// func chacha20Poly1305Open(dst, key, src, ad []byte) bool +TEXT ·chacha20Poly1305Open(SB), 0, $288-97 + // For aligned stack access + MOVQ SP, BP + ADDQ $32, BP + ANDQ $-32, BP + MOVQ dst+0(FP), oup + MOVQ key+24(FP), keyp + MOVQ src+48(FP), inp + MOVQ src_len+56(FP), inl + MOVQ ad+72(FP), adp + + // Check for AVX2 support + CMPB ·useAVX2(SB), $1 + JE chacha20Poly1305Open_AVX2 + + // Special optimization, for very short buffers + CMPQ inl, $128 + JBE openSSE128 // About 16% faster + + // For long buffers, prepare the poly key first + MOVOU ·chacha20Constants<>(SB), A0 + MOVOU (1*16)(keyp), B0 + MOVOU (2*16)(keyp), C0 + MOVOU (3*16)(keyp), D0 + MOVO D0, T1 + + // Store state on stack for future use + MOVO B0, state1Store + MOVO C0, state2Store + MOVO D0, ctr3Store + MOVQ $10, itr2 + +openSSEPreparePolyKey: + chachaQR(A0, B0, C0, D0, T0) + shiftB0Left; shiftC0Left; shiftD0Left + chachaQR(A0, B0, C0, D0, T0) + shiftB0Right; shiftC0Right; shiftD0Right + DECQ itr2 + JNE openSSEPreparePolyKey + + // A0|B0 hold the Poly1305 32-byte key, C0,D0 can be discarded + PADDL ·chacha20Constants<>(SB), A0; PADDL state1Store, B0 + + // Clamp and store the key + PAND ·polyClampMask<>(SB), A0 + MOVO A0, rStore; MOVO B0, sStore + + // Hash AAD + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + +openSSEMainLoop: + CMPQ inl, $256 + JB openSSEMainLoopDone + + // Load state, increment counter blocks + MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0 + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1 + MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2 + MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3 + + // Store counters + MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store + + // There are 10 ChaCha20 iterations of 2QR each, so for 6 iterations we hash 2 blocks, and for the remaining 4 only 1 block - for a total of 16 + MOVQ $4, itr1 + MOVQ inp, itr2 + +openSSEInternalLoop: + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + polyAdd(0(itr2)) + shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left + shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left + shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left + polyMulStage1 + polyMulStage2 + LEAQ (2*8)(itr2), itr2 + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + polyMulStage3 + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + polyMulReduceStage + shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right + shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right + shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right + DECQ itr1 + JGE openSSEInternalLoop + + polyAdd(0(itr2)) + polyMul + LEAQ (2*8)(itr2), itr2 + + CMPQ itr1, $-6 + JG openSSEInternalLoop + + // Add in the state + PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3 + PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3 + PADDD state2Store, C0; PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3 + PADDD ctr0Store, D0; PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3 + + // Load - xor - store + MOVO D3, tmpStore + MOVOU (0*16)(inp), D3; PXOR D3, A0; MOVOU A0, (0*16)(oup) + MOVOU (1*16)(inp), D3; PXOR D3, B0; MOVOU B0, (1*16)(oup) + MOVOU (2*16)(inp), D3; PXOR D3, C0; MOVOU C0, (2*16)(oup) + MOVOU (3*16)(inp), D3; PXOR D3, D0; MOVOU D0, (3*16)(oup) + MOVOU (4*16)(inp), D0; PXOR D0, A1; MOVOU A1, (4*16)(oup) + MOVOU (5*16)(inp), D0; PXOR D0, B1; MOVOU B1, (5*16)(oup) + MOVOU (6*16)(inp), D0; PXOR D0, C1; MOVOU C1, (6*16)(oup) + MOVOU (7*16)(inp), D0; PXOR D0, D1; MOVOU D1, (7*16)(oup) + MOVOU (8*16)(inp), D0; PXOR D0, A2; MOVOU A2, (8*16)(oup) + MOVOU (9*16)(inp), D0; PXOR D0, B2; MOVOU B2, (9*16)(oup) + MOVOU (10*16)(inp), D0; PXOR D0, C2; MOVOU C2, (10*16)(oup) + MOVOU (11*16)(inp), D0; PXOR D0, D2; MOVOU D2, (11*16)(oup) + MOVOU (12*16)(inp), D0; PXOR D0, A3; MOVOU A3, (12*16)(oup) + MOVOU (13*16)(inp), D0; PXOR D0, B3; MOVOU B3, (13*16)(oup) + MOVOU (14*16)(inp), D0; PXOR D0, C3; MOVOU C3, (14*16)(oup) + MOVOU (15*16)(inp), D0; PXOR tmpStore, D0; MOVOU D0, (15*16)(oup) + LEAQ 256(inp), inp + LEAQ 256(oup), oup + SUBQ $256, inl + JMP openSSEMainLoop + +openSSEMainLoopDone: + // Handle the various tail sizes efficiently + TESTQ inl, inl + JE openSSEFinalize + CMPQ inl, $64 + JBE openSSETail64 + CMPQ inl, $128 + JBE openSSETail128 + CMPQ inl, $192 + JBE openSSETail192 + JMP openSSETail256 + +openSSEFinalize: + // Hash in the PT, AAD lengths + ADDQ ad_len+80(FP), acc0; ADCQ src_len+56(FP), acc1; ADCQ $1, acc2 + polyMul + + // Final reduce + MOVQ acc0, t0 + MOVQ acc1, t1 + MOVQ acc2, t2 + SUBQ $-5, acc0 + SBBQ $-1, acc1 + SBBQ $3, acc2 + CMOVQCS t0, acc0 + CMOVQCS t1, acc1 + CMOVQCS t2, acc2 + + // Add in the "s" part of the key + ADDQ 0+sStore, acc0 + ADCQ 8+sStore, acc1 + + // Finally, constant time compare to the tag at the end of the message + XORQ AX, AX + MOVQ $1, DX + XORQ (0*8)(inp), acc0 + XORQ (1*8)(inp), acc1 + ORQ acc1, acc0 + CMOVQEQ DX, AX + + // Return true iff tags are equal + MOVB AX, ret+96(FP) + RET + +// ---------------------------------------------------------------------------- +// Special optimization for buffers smaller than 129 bytes +openSSE128: + // For up to 128 bytes of ciphertext and 64 bytes for the poly key, we require to process three blocks + MOVOU ·chacha20Constants<>(SB), A0; MOVOU (1*16)(keyp), B0; MOVOU (2*16)(keyp), C0; MOVOU (3*16)(keyp), D0 + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1 + MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2 + MOVO B0, T1; MOVO C0, T2; MOVO D1, T3 + MOVQ $10, itr2 + +openSSE128InnerCipherLoop: + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Left; shiftB1Left; shiftB2Left + shiftC0Left; shiftC1Left; shiftC2Left + shiftD0Left; shiftD1Left; shiftD2Left + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Right; shiftB1Right; shiftB2Right + shiftC0Right; shiftC1Right; shiftC2Right + shiftD0Right; shiftD1Right; shiftD2Right + DECQ itr2 + JNE openSSE128InnerCipherLoop + + // A0|B0 hold the Poly1305 32-byte key, C0,D0 can be discarded + PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2 + PADDL T1, B0; PADDL T1, B1; PADDL T1, B2 + PADDL T2, C1; PADDL T2, C2 + PADDL T3, D1; PADDL ·sseIncMask<>(SB), T3; PADDL T3, D2 + + // Clamp and store the key + PAND ·polyClampMask<>(SB), A0 + MOVOU A0, rStore; MOVOU B0, sStore + + // Hash + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + +openSSE128Open: + CMPQ inl, $16 + JB openSSETail16 + SUBQ $16, inl + + // Load for hashing + polyAdd(0(inp)) + + // Load for decryption + MOVOU (inp), T0; PXOR T0, A1; MOVOU A1, (oup) + LEAQ (1*16)(inp), inp + LEAQ (1*16)(oup), oup + polyMul + + // Shift the stream "left" + MOVO B1, A1 + MOVO C1, B1 + MOVO D1, C1 + MOVO A2, D1 + MOVO B2, A2 + MOVO C2, B2 + MOVO D2, C2 + JMP openSSE128Open + +openSSETail16: + TESTQ inl, inl + JE openSSEFinalize + + // We can safely load the CT from the end, because it is padded with the MAC + MOVQ inl, itr2 + SHLQ $4, itr2 + LEAQ ·andMask<>(SB), t0 + MOVOU (inp), T0 + ADDQ inl, inp + PAND -16(t0)(itr2*1), T0 + MOVO T0, 0+tmpStore + MOVQ T0, t0 + MOVQ 8+tmpStore, t1 + PXOR A1, T0 + + // We can only store one byte at a time, since plaintext can be shorter than 16 bytes +openSSETail16Store: + MOVQ T0, t3 + MOVB t3, (oup) + PSRLDQ $1, T0 + INCQ oup + DECQ inl + JNE openSSETail16Store + ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2 + polyMul + JMP openSSEFinalize + +// ---------------------------------------------------------------------------- +// Special optimization for the last 64 bytes of ciphertext +openSSETail64: + // Need to decrypt up to 64 bytes - prepare single block + MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr0Store + XORQ itr2, itr2 + MOVQ inl, itr1 + CMPQ itr1, $16 + JB openSSETail64LoopB + +openSSETail64LoopA: + // Perform ChaCha rounds, while hashing the remaining input + polyAdd(0(inp)(itr2*1)) + polyMul + SUBQ $16, itr1 + +openSSETail64LoopB: + ADDQ $16, itr2 + chachaQR(A0, B0, C0, D0, T0) + shiftB0Left; shiftC0Left; shiftD0Left + chachaQR(A0, B0, C0, D0, T0) + shiftB0Right; shiftC0Right; shiftD0Right + + CMPQ itr1, $16 + JAE openSSETail64LoopA + + CMPQ itr2, $160 + JNE openSSETail64LoopB + + PADDL ·chacha20Constants<>(SB), A0; PADDL state1Store, B0; PADDL state2Store, C0; PADDL ctr0Store, D0 + +openSSETail64DecLoop: + CMPQ inl, $16 + JB openSSETail64DecLoopDone + SUBQ $16, inl + MOVOU (inp), T0 + PXOR T0, A0 + MOVOU A0, (oup) + LEAQ 16(inp), inp + LEAQ 16(oup), oup + MOVO B0, A0 + MOVO C0, B0 + MOVO D0, C0 + JMP openSSETail64DecLoop + +openSSETail64DecLoopDone: + MOVO A0, A1 + JMP openSSETail16 + +// ---------------------------------------------------------------------------- +// Special optimization for the last 128 bytes of ciphertext +openSSETail128: + // Need to decrypt up to 128 bytes - prepare two blocks + MOVO ·chacha20Constants<>(SB), A1; MOVO state1Store, B1; MOVO state2Store, C1; MOVO ctr3Store, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr0Store + MOVO A1, A0; MOVO B1, B0; MOVO C1, C0; MOVO D1, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr1Store + XORQ itr2, itr2 + MOVQ inl, itr1 + ANDQ $-16, itr1 + +openSSETail128LoopA: + // Perform ChaCha rounds, while hashing the remaining input + polyAdd(0(inp)(itr2*1)) + polyMul + +openSSETail128LoopB: + ADDQ $16, itr2 + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0) + shiftB0Left; shiftC0Left; shiftD0Left + shiftB1Left; shiftC1Left; shiftD1Left + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0) + shiftB0Right; shiftC0Right; shiftD0Right + shiftB1Right; shiftC1Right; shiftD1Right + + CMPQ itr2, itr1 + JB openSSETail128LoopA + + CMPQ itr2, $160 + JNE openSSETail128LoopB + + PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1 + PADDL state1Store, B0; PADDL state1Store, B1 + PADDL state2Store, C0; PADDL state2Store, C1 + PADDL ctr1Store, D0; PADDL ctr0Store, D1 + + MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3 + PXOR T0, A1; PXOR T1, B1; PXOR T2, C1; PXOR T3, D1 + MOVOU A1, (0*16)(oup); MOVOU B1, (1*16)(oup); MOVOU C1, (2*16)(oup); MOVOU D1, (3*16)(oup) + + SUBQ $64, inl + LEAQ 64(inp), inp + LEAQ 64(oup), oup + JMP openSSETail64DecLoop + +// ---------------------------------------------------------------------------- +// Special optimization for the last 192 bytes of ciphertext +openSSETail192: + // Need to decrypt up to 192 bytes - prepare three blocks + MOVO ·chacha20Constants<>(SB), A2; MOVO state1Store, B2; MOVO state2Store, C2; MOVO ctr3Store, D2; PADDL ·sseIncMask<>(SB), D2; MOVO D2, ctr0Store + MOVO A2, A1; MOVO B2, B1; MOVO C2, C1; MOVO D2, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store + MOVO A1, A0; MOVO B1, B0; MOVO C1, C0; MOVO D1, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr2Store + + MOVQ inl, itr1 + MOVQ $160, itr2 + CMPQ itr1, $160 + CMOVQGT itr2, itr1 + ANDQ $-16, itr1 + XORQ itr2, itr2 + +openSSLTail192LoopA: + // Perform ChaCha rounds, while hashing the remaining input + polyAdd(0(inp)(itr2*1)) + polyMul + +openSSLTail192LoopB: + ADDQ $16, itr2 + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Left; shiftC0Left; shiftD0Left + shiftB1Left; shiftC1Left; shiftD1Left + shiftB2Left; shiftC2Left; shiftD2Left + + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Right; shiftC0Right; shiftD0Right + shiftB1Right; shiftC1Right; shiftD1Right + shiftB2Right; shiftC2Right; shiftD2Right + + CMPQ itr2, itr1 + JB openSSLTail192LoopA + + CMPQ itr2, $160 + JNE openSSLTail192LoopB + + CMPQ inl, $176 + JB openSSLTail192Store + + polyAdd(160(inp)) + polyMul + + CMPQ inl, $192 + JB openSSLTail192Store + + polyAdd(176(inp)) + polyMul + +openSSLTail192Store: + PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2 + PADDL state1Store, B0; PADDL state1Store, B1; PADDL state1Store, B2 + PADDL state2Store, C0; PADDL state2Store, C1; PADDL state2Store, C2 + PADDL ctr2Store, D0; PADDL ctr1Store, D1; PADDL ctr0Store, D2 + + MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3 + PXOR T0, A2; PXOR T1, B2; PXOR T2, C2; PXOR T3, D2 + MOVOU A2, (0*16)(oup); MOVOU B2, (1*16)(oup); MOVOU C2, (2*16)(oup); MOVOU D2, (3*16)(oup) + + MOVOU (4*16)(inp), T0; MOVOU (5*16)(inp), T1; MOVOU (6*16)(inp), T2; MOVOU (7*16)(inp), T3 + PXOR T0, A1; PXOR T1, B1; PXOR T2, C1; PXOR T3, D1 + MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup) + + SUBQ $128, inl + LEAQ 128(inp), inp + LEAQ 128(oup), oup + JMP openSSETail64DecLoop + +// ---------------------------------------------------------------------------- +// Special optimization for the last 256 bytes of ciphertext +openSSETail256: + // Need to decrypt up to 256 bytes - prepare four blocks + MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0 + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1 + MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2 + MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3 + + // Store counters + MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store + XORQ itr2, itr2 + +openSSETail256Loop: + // This loop inteleaves 8 ChaCha quarter rounds with 1 poly multiplication + polyAdd(0(inp)(itr2*1)) + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left + shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left + shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left + polyMulStage1 + polyMulStage2 + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + polyMulStage3 + polyMulReduceStage + shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right + shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right + shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right + ADDQ $2*8, itr2 + CMPQ itr2, $160 + JB openSSETail256Loop + MOVQ inl, itr1 + ANDQ $-16, itr1 + +openSSETail256HashLoop: + polyAdd(0(inp)(itr2*1)) + polyMul + ADDQ $2*8, itr2 + CMPQ itr2, itr1 + JB openSSETail256HashLoop + + // Add in the state + PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3 + PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3 + PADDD state2Store, C0; PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3 + PADDD ctr0Store, D0; PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3 + MOVO D3, tmpStore + + // Load - xor - store + MOVOU (0*16)(inp), D3; PXOR D3, A0 + MOVOU (1*16)(inp), D3; PXOR D3, B0 + MOVOU (2*16)(inp), D3; PXOR D3, C0 + MOVOU (3*16)(inp), D3; PXOR D3, D0 + MOVOU A0, (0*16)(oup) + MOVOU B0, (1*16)(oup) + MOVOU C0, (2*16)(oup) + MOVOU D0, (3*16)(oup) + MOVOU (4*16)(inp), A0; MOVOU (5*16)(inp), B0; MOVOU (6*16)(inp), C0; MOVOU (7*16)(inp), D0 + PXOR A0, A1; PXOR B0, B1; PXOR C0, C1; PXOR D0, D1 + MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup) + MOVOU (8*16)(inp), A0; MOVOU (9*16)(inp), B0; MOVOU (10*16)(inp), C0; MOVOU (11*16)(inp), D0 + PXOR A0, A2; PXOR B0, B2; PXOR C0, C2; PXOR D0, D2 + MOVOU A2, (8*16)(oup); MOVOU B2, (9*16)(oup); MOVOU C2, (10*16)(oup); MOVOU D2, (11*16)(oup) + LEAQ 192(inp), inp + LEAQ 192(oup), oup + SUBQ $192, inl + MOVO A3, A0 + MOVO B3, B0 + MOVO C3, C0 + MOVO tmpStore, D0 + + JMP openSSETail64DecLoop + +// ---------------------------------------------------------------------------- +// ------------------------- AVX2 Code ---------------------------------------- +chacha20Poly1305Open_AVX2: + VZEROUPPER + VMOVDQU ·chacha20Constants<>(SB), AA0 + BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x70; BYTE $0x10 // broadcasti128 16(r8), ymm14 + BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x20 // broadcasti128 32(r8), ymm12 + BYTE $0xc4; BYTE $0xc2; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x30 // broadcasti128 48(r8), ymm4 + VPADDD ·avx2InitMask<>(SB), DD0, DD0 + + // Special optimization, for very short buffers + CMPQ inl, $192 + JBE openAVX2192 + CMPQ inl, $320 + JBE openAVX2320 + + // For the general key prepare the key first - as a byproduct we have 64 bytes of cipher stream + VMOVDQA BB0, state1StoreAVX2 + VMOVDQA CC0, state2StoreAVX2 + VMOVDQA DD0, ctr3StoreAVX2 + MOVQ $10, itr2 + +openAVX2PreparePolyKey: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0) + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $12, DD0, DD0, DD0 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0) + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $4, DD0, DD0, DD0 + DECQ itr2 + JNE openAVX2PreparePolyKey + + VPADDD ·chacha20Constants<>(SB), AA0, AA0 + VPADDD state1StoreAVX2, BB0, BB0 + VPADDD state2StoreAVX2, CC0, CC0 + VPADDD ctr3StoreAVX2, DD0, DD0 + + VPERM2I128 $0x02, AA0, BB0, TT0 + + // Clamp and store poly key + VPAND ·polyClampMask<>(SB), TT0, TT0 + VMOVDQA TT0, rsStoreAVX2 + + // Stream for the first 64 bytes + VPERM2I128 $0x13, AA0, BB0, AA0 + VPERM2I128 $0x13, CC0, DD0, BB0 + + // Hash AD + first 64 bytes + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + XORQ itr1, itr1 + +openAVX2InitialHash64: + polyAdd(0(inp)(itr1*1)) + polyMulAVX2 + ADDQ $16, itr1 + CMPQ itr1, $64 + JNE openAVX2InitialHash64 + + // Decrypt the first 64 bytes + VPXOR (0*32)(inp), AA0, AA0 + VPXOR (1*32)(inp), BB0, BB0 + VMOVDQU AA0, (0*32)(oup) + VMOVDQU BB0, (1*32)(oup) + LEAQ (2*32)(inp), inp + LEAQ (2*32)(oup), oup + SUBQ $64, inl + +openAVX2MainLoop: + CMPQ inl, $512 + JB openAVX2MainLoopDone + + // Load state, increment counter blocks, store the incremented counters + VMOVDQU ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3 + VMOVDQA ctr3StoreAVX2, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3 + VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2 + XORQ itr1, itr1 + +openAVX2InternalLoop: + // Lets just say this spaghetti loop interleaves 2 quarter rounds with 3 poly multiplications + // Effectively per 512 bytes of stream we hash 480 bytes of ciphertext + polyAdd(0*8(inp)(itr1*1)) + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + polyMulStage1_AVX2 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + polyMulStage2_AVX2 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + polyMulStage3_AVX2 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulReduceStage + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + polyAdd(2*8(inp)(itr1*1)) + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + polyMulStage1_AVX2 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulStage2_AVX2 + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + polyMulStage3_AVX2 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + polyMulReduceStage + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + polyAdd(4*8(inp)(itr1*1)) + LEAQ (6*8)(itr1), itr1 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulStage1_AVX2 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + polyMulStage2_AVX2 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + polyMulStage3_AVX2 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulReduceStage + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3 + CMPQ itr1, $480 + JNE openAVX2InternalLoop + + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3 + VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3 + VMOVDQA CC3, tmpStoreAVX2 + + // We only hashed 480 of the 512 bytes available - hash the remaining 32 here + polyAdd(480(inp)) + polyMulAVX2 + VPERM2I128 $0x02, AA0, BB0, CC3; VPERM2I128 $0x13, AA0, BB0, BB0; VPERM2I128 $0x02, CC0, DD0, AA0; VPERM2I128 $0x13, CC0, DD0, CC0 + VPXOR (0*32)(inp), CC3, CC3; VPXOR (1*32)(inp), AA0, AA0; VPXOR (2*32)(inp), BB0, BB0; VPXOR (3*32)(inp), CC0, CC0 + VMOVDQU CC3, (0*32)(oup); VMOVDQU AA0, (1*32)(oup); VMOVDQU BB0, (2*32)(oup); VMOVDQU CC0, (3*32)(oup) + VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0 + VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0 + VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup) + + // and here + polyAdd(496(inp)) + polyMulAVX2 + VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0 + VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0 + VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup) + VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0 + VPXOR (12*32)(inp), AA0, AA0; VPXOR (13*32)(inp), BB0, BB0; VPXOR (14*32)(inp), CC0, CC0; VPXOR (15*32)(inp), DD0, DD0 + VMOVDQU AA0, (12*32)(oup); VMOVDQU BB0, (13*32)(oup); VMOVDQU CC0, (14*32)(oup); VMOVDQU DD0, (15*32)(oup) + LEAQ (32*16)(inp), inp + LEAQ (32*16)(oup), oup + SUBQ $(32*16), inl + JMP openAVX2MainLoop + +openAVX2MainLoopDone: + // Handle the various tail sizes efficiently + TESTQ inl, inl + JE openSSEFinalize + CMPQ inl, $128 + JBE openAVX2Tail128 + CMPQ inl, $256 + JBE openAVX2Tail256 + CMPQ inl, $384 + JBE openAVX2Tail384 + JMP openAVX2Tail512 + +// ---------------------------------------------------------------------------- +// Special optimization for buffers smaller than 193 bytes +openAVX2192: + // For up to 192 bytes of ciphertext and 64 bytes for the poly key, we process four blocks + VMOVDQA AA0, AA1 + VMOVDQA BB0, BB1 + VMOVDQA CC0, CC1 + VPADDD ·avx2IncMask<>(SB), DD0, DD1 + VMOVDQA AA0, AA2 + VMOVDQA BB0, BB2 + VMOVDQA CC0, CC2 + VMOVDQA DD0, DD2 + VMOVDQA DD1, TT3 + MOVQ $10, itr2 + +openAVX2192InnerCipherLoop: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1 + DECQ itr2 + JNE openAVX2192InnerCipherLoop + VPADDD AA2, AA0, AA0; VPADDD AA2, AA1, AA1 + VPADDD BB2, BB0, BB0; VPADDD BB2, BB1, BB1 + VPADDD CC2, CC0, CC0; VPADDD CC2, CC1, CC1 + VPADDD DD2, DD0, DD0; VPADDD TT3, DD1, DD1 + VPERM2I128 $0x02, AA0, BB0, TT0 + + // Clamp and store poly key + VPAND ·polyClampMask<>(SB), TT0, TT0 + VMOVDQA TT0, rsStoreAVX2 + + // Stream for up to 192 bytes + VPERM2I128 $0x13, AA0, BB0, AA0 + VPERM2I128 $0x13, CC0, DD0, BB0 + VPERM2I128 $0x02, AA1, BB1, CC0 + VPERM2I128 $0x02, CC1, DD1, DD0 + VPERM2I128 $0x13, AA1, BB1, AA1 + VPERM2I128 $0x13, CC1, DD1, BB1 + +openAVX2ShortOpen: + // Hash + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + +openAVX2ShortOpenLoop: + CMPQ inl, $32 + JB openAVX2ShortTail32 + SUBQ $32, inl + + // Load for hashing + polyAdd(0*8(inp)) + polyMulAVX2 + polyAdd(2*8(inp)) + polyMulAVX2 + + // Load for decryption + VPXOR (inp), AA0, AA0 + VMOVDQU AA0, (oup) + LEAQ (1*32)(inp), inp + LEAQ (1*32)(oup), oup + + // Shift stream left + VMOVDQA BB0, AA0 + VMOVDQA CC0, BB0 + VMOVDQA DD0, CC0 + VMOVDQA AA1, DD0 + VMOVDQA BB1, AA1 + VMOVDQA CC1, BB1 + VMOVDQA DD1, CC1 + VMOVDQA AA2, DD1 + VMOVDQA BB2, AA2 + JMP openAVX2ShortOpenLoop + +openAVX2ShortTail32: + CMPQ inl, $16 + VMOVDQA A0, A1 + JB openAVX2ShortDone + + SUBQ $16, inl + + // Load for hashing + polyAdd(0*8(inp)) + polyMulAVX2 + + // Load for decryption + VPXOR (inp), A0, T0 + VMOVDQU T0, (oup) + LEAQ (1*16)(inp), inp + LEAQ (1*16)(oup), oup + VPERM2I128 $0x11, AA0, AA0, AA0 + VMOVDQA A0, A1 + +openAVX2ShortDone: + VZEROUPPER + JMP openSSETail16 + +// ---------------------------------------------------------------------------- +// Special optimization for buffers smaller than 321 bytes +openAVX2320: + // For up to 320 bytes of ciphertext and 64 bytes for the poly key, we process six blocks + VMOVDQA AA0, AA1; VMOVDQA BB0, BB1; VMOVDQA CC0, CC1; VPADDD ·avx2IncMask<>(SB), DD0, DD1 + VMOVDQA AA0, AA2; VMOVDQA BB0, BB2; VMOVDQA CC0, CC2; VPADDD ·avx2IncMask<>(SB), DD1, DD2 + VMOVDQA BB0, TT1; VMOVDQA CC0, TT2; VMOVDQA DD0, TT3 + MOVQ $10, itr2 + +openAVX2320InnerCipherLoop: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2 + DECQ itr2 + JNE openAVX2320InnerCipherLoop + + VMOVDQA ·chacha20Constants<>(SB), TT0 + VPADDD TT0, AA0, AA0; VPADDD TT0, AA1, AA1; VPADDD TT0, AA2, AA2 + VPADDD TT1, BB0, BB0; VPADDD TT1, BB1, BB1; VPADDD TT1, BB2, BB2 + VPADDD TT2, CC0, CC0; VPADDD TT2, CC1, CC1; VPADDD TT2, CC2, CC2 + VMOVDQA ·avx2IncMask<>(SB), TT0 + VPADDD TT3, DD0, DD0; VPADDD TT0, TT3, TT3 + VPADDD TT3, DD1, DD1; VPADDD TT0, TT3, TT3 + VPADDD TT3, DD2, DD2 + + // Clamp and store poly key + VPERM2I128 $0x02, AA0, BB0, TT0 + VPAND ·polyClampMask<>(SB), TT0, TT0 + VMOVDQA TT0, rsStoreAVX2 + + // Stream for up to 320 bytes + VPERM2I128 $0x13, AA0, BB0, AA0 + VPERM2I128 $0x13, CC0, DD0, BB0 + VPERM2I128 $0x02, AA1, BB1, CC0 + VPERM2I128 $0x02, CC1, DD1, DD0 + VPERM2I128 $0x13, AA1, BB1, AA1 + VPERM2I128 $0x13, CC1, DD1, BB1 + VPERM2I128 $0x02, AA2, BB2, CC1 + VPERM2I128 $0x02, CC2, DD2, DD1 + VPERM2I128 $0x13, AA2, BB2, AA2 + VPERM2I128 $0x13, CC2, DD2, BB2 + JMP openAVX2ShortOpen + +// ---------------------------------------------------------------------------- +// Special optimization for the last 128 bytes of ciphertext +openAVX2Tail128: + // Need to decrypt up to 128 bytes - prepare two blocks + VMOVDQA ·chacha20Constants<>(SB), AA1 + VMOVDQA state1StoreAVX2, BB1 + VMOVDQA state2StoreAVX2, CC1 + VMOVDQA ctr3StoreAVX2, DD1 + VPADDD ·avx2IncMask<>(SB), DD1, DD1 + VMOVDQA DD1, DD0 + + XORQ itr2, itr2 + MOVQ inl, itr1 + ANDQ $-16, itr1 + TESTQ itr1, itr1 + JE openAVX2Tail128LoopB + +openAVX2Tail128LoopA: + // Perform ChaCha rounds, while hashing the remaining input + polyAdd(0(inp)(itr2*1)) + polyMulAVX2 + +openAVX2Tail128LoopB: + ADDQ $16, itr2 + chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $4, BB1, BB1, BB1 + VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $12, DD1, DD1, DD1 + chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $12, BB1, BB1, BB1 + VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $4, DD1, DD1, DD1 + CMPQ itr2, itr1 + JB openAVX2Tail128LoopA + CMPQ itr2, $160 + JNE openAVX2Tail128LoopB + + VPADDD ·chacha20Constants<>(SB), AA1, AA1 + VPADDD state1StoreAVX2, BB1, BB1 + VPADDD state2StoreAVX2, CC1, CC1 + VPADDD DD0, DD1, DD1 + VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0 + +openAVX2TailLoop: + CMPQ inl, $32 + JB openAVX2Tail + SUBQ $32, inl + + // Load for decryption + VPXOR (inp), AA0, AA0 + VMOVDQU AA0, (oup) + LEAQ (1*32)(inp), inp + LEAQ (1*32)(oup), oup + VMOVDQA BB0, AA0 + VMOVDQA CC0, BB0 + VMOVDQA DD0, CC0 + JMP openAVX2TailLoop + +openAVX2Tail: + CMPQ inl, $16 + VMOVDQA A0, A1 + JB openAVX2TailDone + SUBQ $16, inl + + // Load for decryption + VPXOR (inp), A0, T0 + VMOVDQU T0, (oup) + LEAQ (1*16)(inp), inp + LEAQ (1*16)(oup), oup + VPERM2I128 $0x11, AA0, AA0, AA0 + VMOVDQA A0, A1 + +openAVX2TailDone: + VZEROUPPER + JMP openSSETail16 + +// ---------------------------------------------------------------------------- +// Special optimization for the last 256 bytes of ciphertext +openAVX2Tail256: + // Need to decrypt up to 256 bytes - prepare four blocks + VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1 + VMOVDQA ctr3StoreAVX2, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD1 + VMOVDQA DD0, TT1 + VMOVDQA DD1, TT2 + + // Compute the number of iterations that will hash data + MOVQ inl, tmpStoreAVX2 + MOVQ inl, itr1 + SUBQ $128, itr1 + SHRQ $4, itr1 + MOVQ $10, itr2 + CMPQ itr1, $10 + CMOVQGT itr2, itr1 + MOVQ inp, inl + XORQ itr2, itr2 + +openAVX2Tail256LoopA: + polyAdd(0(inl)) + polyMulAVX2 + LEAQ 16(inl), inl + + // Perform ChaCha rounds, while hashing the remaining input +openAVX2Tail256LoopB: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1 + INCQ itr2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1 + CMPQ itr2, itr1 + JB openAVX2Tail256LoopA + + CMPQ itr2, $10 + JNE openAVX2Tail256LoopB + + MOVQ inl, itr2 + SUBQ inp, inl + MOVQ inl, itr1 + MOVQ tmpStoreAVX2, inl + + // Hash the remainder of data (if any) +openAVX2Tail256Hash: + ADDQ $16, itr1 + CMPQ itr1, inl + JGT openAVX2Tail256HashEnd + polyAdd (0(itr2)) + polyMulAVX2 + LEAQ 16(itr2), itr2 + JMP openAVX2Tail256Hash + +// Store 128 bytes safely, then go to store loop +openAVX2Tail256HashEnd: + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1 + VPADDD TT1, DD0, DD0; VPADDD TT2, DD1, DD1 + VPERM2I128 $0x02, AA0, BB0, AA2; VPERM2I128 $0x02, CC0, DD0, BB2; VPERM2I128 $0x13, AA0, BB0, CC2; VPERM2I128 $0x13, CC0, DD0, DD2 + VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0 + + VPXOR (0*32)(inp), AA2, AA2; VPXOR (1*32)(inp), BB2, BB2; VPXOR (2*32)(inp), CC2, CC2; VPXOR (3*32)(inp), DD2, DD2 + VMOVDQU AA2, (0*32)(oup); VMOVDQU BB2, (1*32)(oup); VMOVDQU CC2, (2*32)(oup); VMOVDQU DD2, (3*32)(oup) + LEAQ (4*32)(inp), inp + LEAQ (4*32)(oup), oup + SUBQ $4*32, inl + + JMP openAVX2TailLoop + +// ---------------------------------------------------------------------------- +// Special optimization for the last 384 bytes of ciphertext +openAVX2Tail384: + // Need to decrypt up to 384 bytes - prepare six blocks + VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2 + VMOVDQA ctr3StoreAVX2, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD1 + VPADDD ·avx2IncMask<>(SB), DD1, DD2 + VMOVDQA DD0, ctr0StoreAVX2 + VMOVDQA DD1, ctr1StoreAVX2 + VMOVDQA DD2, ctr2StoreAVX2 + + // Compute the number of iterations that will hash two blocks of data + MOVQ inl, tmpStoreAVX2 + MOVQ inl, itr1 + SUBQ $256, itr1 + SHRQ $4, itr1 + ADDQ $6, itr1 + MOVQ $10, itr2 + CMPQ itr1, $10 + CMOVQGT itr2, itr1 + MOVQ inp, inl + XORQ itr2, itr2 + + // Perform ChaCha rounds, while hashing the remaining input +openAVX2Tail384LoopB: + polyAdd(0(inl)) + polyMulAVX2 + LEAQ 16(inl), inl + +openAVX2Tail384LoopA: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2 + polyAdd(0(inl)) + polyMulAVX2 + LEAQ 16(inl), inl + INCQ itr2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2 + + CMPQ itr2, itr1 + JB openAVX2Tail384LoopB + + CMPQ itr2, $10 + JNE openAVX2Tail384LoopA + + MOVQ inl, itr2 + SUBQ inp, inl + MOVQ inl, itr1 + MOVQ tmpStoreAVX2, inl + +openAVX2Tail384Hash: + ADDQ $16, itr1 + CMPQ itr1, inl + JGT openAVX2Tail384HashEnd + polyAdd(0(itr2)) + polyMulAVX2 + LEAQ 16(itr2), itr2 + JMP openAVX2Tail384Hash + +// Store 256 bytes safely, then go to store loop +openAVX2Tail384HashEnd: + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2 + VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2 + VPERM2I128 $0x02, AA0, BB0, TT0; VPERM2I128 $0x02, CC0, DD0, TT1; VPERM2I128 $0x13, AA0, BB0, TT2; VPERM2I128 $0x13, CC0, DD0, TT3 + VPXOR (0*32)(inp), TT0, TT0; VPXOR (1*32)(inp), TT1, TT1; VPXOR (2*32)(inp), TT2, TT2; VPXOR (3*32)(inp), TT3, TT3 + VMOVDQU TT0, (0*32)(oup); VMOVDQU TT1, (1*32)(oup); VMOVDQU TT2, (2*32)(oup); VMOVDQU TT3, (3*32)(oup) + VPERM2I128 $0x02, AA1, BB1, TT0; VPERM2I128 $0x02, CC1, DD1, TT1; VPERM2I128 $0x13, AA1, BB1, TT2; VPERM2I128 $0x13, CC1, DD1, TT3 + VPXOR (4*32)(inp), TT0, TT0; VPXOR (5*32)(inp), TT1, TT1; VPXOR (6*32)(inp), TT2, TT2; VPXOR (7*32)(inp), TT3, TT3 + VMOVDQU TT0, (4*32)(oup); VMOVDQU TT1, (5*32)(oup); VMOVDQU TT2, (6*32)(oup); VMOVDQU TT3, (7*32)(oup) + VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0 + LEAQ (8*32)(inp), inp + LEAQ (8*32)(oup), oup + SUBQ $8*32, inl + JMP openAVX2TailLoop + +// ---------------------------------------------------------------------------- +// Special optimization for the last 512 bytes of ciphertext +openAVX2Tail512: + VMOVDQU ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3 + VMOVDQA ctr3StoreAVX2, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3 + VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2 + XORQ itr1, itr1 + MOVQ inp, itr2 + +openAVX2Tail512LoopB: + polyAdd(0(itr2)) + polyMulAVX2 + LEAQ (2*8)(itr2), itr2 + +openAVX2Tail512LoopA: + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyAdd(0*8(itr2)) + polyMulAVX2 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + polyAdd(2*8(itr2)) + polyMulAVX2 + LEAQ (4*8)(itr2), itr2 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3 + INCQ itr1 + CMPQ itr1, $4 + JLT openAVX2Tail512LoopB + + CMPQ itr1, $10 + JNE openAVX2Tail512LoopA + + MOVQ inl, itr1 + SUBQ $384, itr1 + ANDQ $-16, itr1 + +openAVX2Tail512HashLoop: + TESTQ itr1, itr1 + JE openAVX2Tail512HashEnd + polyAdd(0(itr2)) + polyMulAVX2 + LEAQ 16(itr2), itr2 + SUBQ $16, itr1 + JMP openAVX2Tail512HashLoop + +openAVX2Tail512HashEnd: + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3 + VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3 + VMOVDQA CC3, tmpStoreAVX2 + VPERM2I128 $0x02, AA0, BB0, CC3; VPERM2I128 $0x13, AA0, BB0, BB0; VPERM2I128 $0x02, CC0, DD0, AA0; VPERM2I128 $0x13, CC0, DD0, CC0 + VPXOR (0*32)(inp), CC3, CC3; VPXOR (1*32)(inp), AA0, AA0; VPXOR (2*32)(inp), BB0, BB0; VPXOR (3*32)(inp), CC0, CC0 + VMOVDQU CC3, (0*32)(oup); VMOVDQU AA0, (1*32)(oup); VMOVDQU BB0, (2*32)(oup); VMOVDQU CC0, (3*32)(oup) + VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0 + VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0 + VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup) + VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0 + VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0 + VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup) + VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0 + + LEAQ (12*32)(inp), inp + LEAQ (12*32)(oup), oup + SUBQ $12*32, inl + + JMP openAVX2TailLoop + +// ---------------------------------------------------------------------------- +// ---------------------------------------------------------------------------- +// func chacha20Poly1305Seal(dst, key, src, ad []byte) +TEXT ·chacha20Poly1305Seal(SB), 0, $288-96 + // For aligned stack access + MOVQ SP, BP + ADDQ $32, BP + ANDQ $-32, BP + MOVQ dst+0(FP), oup + MOVQ key+24(FP), keyp + MOVQ src+48(FP), inp + MOVQ src_len+56(FP), inl + MOVQ ad+72(FP), adp + + CMPB ·useAVX2(SB), $1 + JE chacha20Poly1305Seal_AVX2 + + // Special optimization, for very short buffers + CMPQ inl, $128 + JBE sealSSE128 // About 15% faster + + // In the seal case - prepare the poly key + 3 blocks of stream in the first iteration + MOVOU ·chacha20Constants<>(SB), A0 + MOVOU (1*16)(keyp), B0 + MOVOU (2*16)(keyp), C0 + MOVOU (3*16)(keyp), D0 + + // Store state on stack for future use + MOVO B0, state1Store + MOVO C0, state2Store + + // Load state, increment counter blocks + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1 + MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2 + MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3 + + // Store counters + MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store + MOVQ $10, itr2 + +sealSSEIntroLoop: + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left + shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left + shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left + + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right + shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right + shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right + DECQ itr2 + JNE sealSSEIntroLoop + + // Add in the state + PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3 + PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3 + PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3 + PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3 + + // Clamp and store the key + PAND ·polyClampMask<>(SB), A0 + MOVO A0, rStore + MOVO B0, sStore + + // Hash AAD + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + + MOVOU (0*16)(inp), A0; MOVOU (1*16)(inp), B0; MOVOU (2*16)(inp), C0; MOVOU (3*16)(inp), D0 + PXOR A0, A1; PXOR B0, B1; PXOR C0, C1; PXOR D0, D1 + MOVOU A1, (0*16)(oup); MOVOU B1, (1*16)(oup); MOVOU C1, (2*16)(oup); MOVOU D1, (3*16)(oup) + MOVOU (4*16)(inp), A0; MOVOU (5*16)(inp), B0; MOVOU (6*16)(inp), C0; MOVOU (7*16)(inp), D0 + PXOR A0, A2; PXOR B0, B2; PXOR C0, C2; PXOR D0, D2 + MOVOU A2, (4*16)(oup); MOVOU B2, (5*16)(oup); MOVOU C2, (6*16)(oup); MOVOU D2, (7*16)(oup) + + MOVQ $128, itr1 + SUBQ $128, inl + LEAQ 128(inp), inp + + MOVO A3, A1; MOVO B3, B1; MOVO C3, C1; MOVO D3, D1 + + CMPQ inl, $64 + JBE sealSSE128SealHash + + MOVOU (0*16)(inp), A0; MOVOU (1*16)(inp), B0; MOVOU (2*16)(inp), C0; MOVOU (3*16)(inp), D0 + PXOR A0, A3; PXOR B0, B3; PXOR C0, C3; PXOR D0, D3 + MOVOU A3, (8*16)(oup); MOVOU B3, (9*16)(oup); MOVOU C3, (10*16)(oup); MOVOU D3, (11*16)(oup) + + ADDQ $64, itr1 + SUBQ $64, inl + LEAQ 64(inp), inp + + MOVQ $2, itr1 + MOVQ $8, itr2 + + CMPQ inl, $64 + JBE sealSSETail64 + CMPQ inl, $128 + JBE sealSSETail128 + CMPQ inl, $192 + JBE sealSSETail192 + +sealSSEMainLoop: + // Load state, increment counter blocks + MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0 + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1 + MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2 + MOVO A2, A3; MOVO B2, B3; MOVO C2, C3; MOVO D2, D3; PADDL ·sseIncMask<>(SB), D3 + + // Store counters + MOVO D0, ctr0Store; MOVO D1, ctr1Store; MOVO D2, ctr2Store; MOVO D3, ctr3Store + +sealSSEInnerLoop: + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + polyAdd(0(oup)) + shiftB0Left; shiftB1Left; shiftB2Left; shiftB3Left + shiftC0Left; shiftC1Left; shiftC2Left; shiftC3Left + shiftD0Left; shiftD1Left; shiftD2Left; shiftD3Left + polyMulStage1 + polyMulStage2 + LEAQ (2*8)(oup), oup + MOVO C3, tmpStore + chachaQR(A0, B0, C0, D0, C3); chachaQR(A1, B1, C1, D1, C3); chachaQR(A2, B2, C2, D2, C3) + MOVO tmpStore, C3 + MOVO C1, tmpStore + polyMulStage3 + chachaQR(A3, B3, C3, D3, C1) + MOVO tmpStore, C1 + polyMulReduceStage + shiftB0Right; shiftB1Right; shiftB2Right; shiftB3Right + shiftC0Right; shiftC1Right; shiftC2Right; shiftC3Right + shiftD0Right; shiftD1Right; shiftD2Right; shiftD3Right + DECQ itr2 + JGE sealSSEInnerLoop + polyAdd(0(oup)) + polyMul + LEAQ (2*8)(oup), oup + DECQ itr1 + JG sealSSEInnerLoop + + // Add in the state + PADDD ·chacha20Constants<>(SB), A0; PADDD ·chacha20Constants<>(SB), A1; PADDD ·chacha20Constants<>(SB), A2; PADDD ·chacha20Constants<>(SB), A3 + PADDD state1Store, B0; PADDD state1Store, B1; PADDD state1Store, B2; PADDD state1Store, B3 + PADDD state2Store, C0; PADDD state2Store, C1; PADDD state2Store, C2; PADDD state2Store, C3 + PADDD ctr0Store, D0; PADDD ctr1Store, D1; PADDD ctr2Store, D2; PADDD ctr3Store, D3 + MOVO D3, tmpStore + + // Load - xor - store + MOVOU (0*16)(inp), D3; PXOR D3, A0 + MOVOU (1*16)(inp), D3; PXOR D3, B0 + MOVOU (2*16)(inp), D3; PXOR D3, C0 + MOVOU (3*16)(inp), D3; PXOR D3, D0 + MOVOU A0, (0*16)(oup) + MOVOU B0, (1*16)(oup) + MOVOU C0, (2*16)(oup) + MOVOU D0, (3*16)(oup) + MOVO tmpStore, D3 + + MOVOU (4*16)(inp), A0; MOVOU (5*16)(inp), B0; MOVOU (6*16)(inp), C0; MOVOU (7*16)(inp), D0 + PXOR A0, A1; PXOR B0, B1; PXOR C0, C1; PXOR D0, D1 + MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup) + MOVOU (8*16)(inp), A0; MOVOU (9*16)(inp), B0; MOVOU (10*16)(inp), C0; MOVOU (11*16)(inp), D0 + PXOR A0, A2; PXOR B0, B2; PXOR C0, C2; PXOR D0, D2 + MOVOU A2, (8*16)(oup); MOVOU B2, (9*16)(oup); MOVOU C2, (10*16)(oup); MOVOU D2, (11*16)(oup) + ADDQ $192, inp + MOVQ $192, itr1 + SUBQ $192, inl + MOVO A3, A1 + MOVO B3, B1 + MOVO C3, C1 + MOVO D3, D1 + CMPQ inl, $64 + JBE sealSSE128SealHash + MOVOU (0*16)(inp), A0; MOVOU (1*16)(inp), B0; MOVOU (2*16)(inp), C0; MOVOU (3*16)(inp), D0 + PXOR A0, A3; PXOR B0, B3; PXOR C0, C3; PXOR D0, D3 + MOVOU A3, (12*16)(oup); MOVOU B3, (13*16)(oup); MOVOU C3, (14*16)(oup); MOVOU D3, (15*16)(oup) + LEAQ 64(inp), inp + SUBQ $64, inl + MOVQ $6, itr1 + MOVQ $4, itr2 + CMPQ inl, $192 + JG sealSSEMainLoop + + MOVQ inl, itr1 + TESTQ inl, inl + JE sealSSE128SealHash + MOVQ $6, itr1 + CMPQ inl, $64 + JBE sealSSETail64 + CMPQ inl, $128 + JBE sealSSETail128 + JMP sealSSETail192 + +// ---------------------------------------------------------------------------- +// Special optimization for the last 64 bytes of plaintext +sealSSETail64: + // Need to encrypt up to 64 bytes - prepare single block, hash 192 or 256 bytes + MOVO ·chacha20Constants<>(SB), A1 + MOVO state1Store, B1 + MOVO state2Store, C1 + MOVO ctr3Store, D1 + PADDL ·sseIncMask<>(SB), D1 + MOVO D1, ctr0Store + +sealSSETail64LoopA: + // Perform ChaCha rounds, while hashing the previously encrypted ciphertext + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + +sealSSETail64LoopB: + chachaQR(A1, B1, C1, D1, T1) + shiftB1Left; shiftC1Left; shiftD1Left + chachaQR(A1, B1, C1, D1, T1) + shiftB1Right; shiftC1Right; shiftD1Right + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + + DECQ itr1 + JG sealSSETail64LoopA + + DECQ itr2 + JGE sealSSETail64LoopB + PADDL ·chacha20Constants<>(SB), A1 + PADDL state1Store, B1 + PADDL state2Store, C1 + PADDL ctr0Store, D1 + + JMP sealSSE128Seal + +// ---------------------------------------------------------------------------- +// Special optimization for the last 128 bytes of plaintext +sealSSETail128: + // Need to encrypt up to 128 bytes - prepare two blocks, hash 192 or 256 bytes + MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr0Store + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store + +sealSSETail128LoopA: + // Perform ChaCha rounds, while hashing the previously encrypted ciphertext + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + +sealSSETail128LoopB: + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0) + shiftB0Left; shiftC0Left; shiftD0Left + shiftB1Left; shiftC1Left; shiftD1Left + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0) + shiftB0Right; shiftC0Right; shiftD0Right + shiftB1Right; shiftC1Right; shiftD1Right + + DECQ itr1 + JG sealSSETail128LoopA + + DECQ itr2 + JGE sealSSETail128LoopB + + PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1 + PADDL state1Store, B0; PADDL state1Store, B1 + PADDL state2Store, C0; PADDL state2Store, C1 + PADDL ctr0Store, D0; PADDL ctr1Store, D1 + + MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3 + PXOR T0, A0; PXOR T1, B0; PXOR T2, C0; PXOR T3, D0 + MOVOU A0, (0*16)(oup); MOVOU B0, (1*16)(oup); MOVOU C0, (2*16)(oup); MOVOU D0, (3*16)(oup) + + MOVQ $64, itr1 + LEAQ 64(inp), inp + SUBQ $64, inl + + JMP sealSSE128SealHash + +// ---------------------------------------------------------------------------- +// Special optimization for the last 192 bytes of plaintext +sealSSETail192: + // Need to encrypt up to 192 bytes - prepare three blocks, hash 192 or 256 bytes + MOVO ·chacha20Constants<>(SB), A0; MOVO state1Store, B0; MOVO state2Store, C0; MOVO ctr3Store, D0; PADDL ·sseIncMask<>(SB), D0; MOVO D0, ctr0Store + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1; MOVO D1, ctr1Store + MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2; MOVO D2, ctr2Store + +sealSSETail192LoopA: + // Perform ChaCha rounds, while hashing the previously encrypted ciphertext + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + +sealSSETail192LoopB: + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Left; shiftC0Left; shiftD0Left + shiftB1Left; shiftC1Left; shiftD1Left + shiftB2Left; shiftC2Left; shiftD2Left + + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Right; shiftC0Right; shiftD0Right + shiftB1Right; shiftC1Right; shiftD1Right + shiftB2Right; shiftC2Right; shiftD2Right + + DECQ itr1 + JG sealSSETail192LoopA + + DECQ itr2 + JGE sealSSETail192LoopB + + PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2 + PADDL state1Store, B0; PADDL state1Store, B1; PADDL state1Store, B2 + PADDL state2Store, C0; PADDL state2Store, C1; PADDL state2Store, C2 + PADDL ctr0Store, D0; PADDL ctr1Store, D1; PADDL ctr2Store, D2 + + MOVOU (0*16)(inp), T0; MOVOU (1*16)(inp), T1; MOVOU (2*16)(inp), T2; MOVOU (3*16)(inp), T3 + PXOR T0, A0; PXOR T1, B0; PXOR T2, C0; PXOR T3, D0 + MOVOU A0, (0*16)(oup); MOVOU B0, (1*16)(oup); MOVOU C0, (2*16)(oup); MOVOU D0, (3*16)(oup) + MOVOU (4*16)(inp), T0; MOVOU (5*16)(inp), T1; MOVOU (6*16)(inp), T2; MOVOU (7*16)(inp), T3 + PXOR T0, A1; PXOR T1, B1; PXOR T2, C1; PXOR T3, D1 + MOVOU A1, (4*16)(oup); MOVOU B1, (5*16)(oup); MOVOU C1, (6*16)(oup); MOVOU D1, (7*16)(oup) + + MOVO A2, A1 + MOVO B2, B1 + MOVO C2, C1 + MOVO D2, D1 + MOVQ $128, itr1 + LEAQ 128(inp), inp + SUBQ $128, inl + + JMP sealSSE128SealHash + +// ---------------------------------------------------------------------------- +// Special seal optimization for buffers smaller than 129 bytes +sealSSE128: + // For up to 128 bytes of ciphertext and 64 bytes for the poly key, we require to process three blocks + MOVOU ·chacha20Constants<>(SB), A0; MOVOU (1*16)(keyp), B0; MOVOU (2*16)(keyp), C0; MOVOU (3*16)(keyp), D0 + MOVO A0, A1; MOVO B0, B1; MOVO C0, C1; MOVO D0, D1; PADDL ·sseIncMask<>(SB), D1 + MOVO A1, A2; MOVO B1, B2; MOVO C1, C2; MOVO D1, D2; PADDL ·sseIncMask<>(SB), D2 + MOVO B0, T1; MOVO C0, T2; MOVO D1, T3 + MOVQ $10, itr2 + +sealSSE128InnerCipherLoop: + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Left; shiftB1Left; shiftB2Left + shiftC0Left; shiftC1Left; shiftC2Left + shiftD0Left; shiftD1Left; shiftD2Left + chachaQR(A0, B0, C0, D0, T0); chachaQR(A1, B1, C1, D1, T0); chachaQR(A2, B2, C2, D2, T0) + shiftB0Right; shiftB1Right; shiftB2Right + shiftC0Right; shiftC1Right; shiftC2Right + shiftD0Right; shiftD1Right; shiftD2Right + DECQ itr2 + JNE sealSSE128InnerCipherLoop + + // A0|B0 hold the Poly1305 32-byte key, C0,D0 can be discarded + PADDL ·chacha20Constants<>(SB), A0; PADDL ·chacha20Constants<>(SB), A1; PADDL ·chacha20Constants<>(SB), A2 + PADDL T1, B0; PADDL T1, B1; PADDL T1, B2 + PADDL T2, C1; PADDL T2, C2 + PADDL T3, D1; PADDL ·sseIncMask<>(SB), T3; PADDL T3, D2 + PAND ·polyClampMask<>(SB), A0 + MOVOU A0, rStore + MOVOU B0, sStore + + // Hash + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + XORQ itr1, itr1 + +sealSSE128SealHash: + // itr1 holds the number of bytes encrypted but not yet hashed + CMPQ itr1, $16 + JB sealSSE128Seal + polyAdd(0(oup)) + polyMul + + SUBQ $16, itr1 + ADDQ $16, oup + + JMP sealSSE128SealHash + +sealSSE128Seal: + CMPQ inl, $16 + JB sealSSETail + SUBQ $16, inl + + // Load for decryption + MOVOU (inp), T0 + PXOR T0, A1 + MOVOU A1, (oup) + LEAQ (1*16)(inp), inp + LEAQ (1*16)(oup), oup + + // Extract for hashing + MOVQ A1, t0 + PSRLDQ $8, A1 + MOVQ A1, t1 + ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2 + polyMul + + // Shift the stream "left" + MOVO B1, A1 + MOVO C1, B1 + MOVO D1, C1 + MOVO A2, D1 + MOVO B2, A2 + MOVO C2, B2 + MOVO D2, C2 + JMP sealSSE128Seal + +sealSSETail: + TESTQ inl, inl + JE sealSSEFinalize + + // We can only load the PT one byte at a time to avoid read after end of buffer + MOVQ inl, itr2 + SHLQ $4, itr2 + LEAQ ·andMask<>(SB), t0 + MOVQ inl, itr1 + LEAQ -1(inp)(inl*1), inp + XORQ t2, t2 + XORQ t3, t3 + XORQ AX, AX + +sealSSETailLoadLoop: + SHLQ $8, t2, t3 + SHLQ $8, t2 + MOVB (inp), AX + XORQ AX, t2 + LEAQ -1(inp), inp + DECQ itr1 + JNE sealSSETailLoadLoop + MOVQ t2, 0+tmpStore + MOVQ t3, 8+tmpStore + PXOR 0+tmpStore, A1 + MOVOU A1, (oup) + MOVOU -16(t0)(itr2*1), T0 + PAND T0, A1 + MOVQ A1, t0 + PSRLDQ $8, A1 + MOVQ A1, t1 + ADDQ t0, acc0; ADCQ t1, acc1; ADCQ $1, acc2 + polyMul + + ADDQ inl, oup + +sealSSEFinalize: + // Hash in the buffer lengths + ADDQ ad_len+80(FP), acc0 + ADCQ src_len+56(FP), acc1 + ADCQ $1, acc2 + polyMul + + // Final reduce + MOVQ acc0, t0 + MOVQ acc1, t1 + MOVQ acc2, t2 + SUBQ $-5, acc0 + SBBQ $-1, acc1 + SBBQ $3, acc2 + CMOVQCS t0, acc0 + CMOVQCS t1, acc1 + CMOVQCS t2, acc2 + + // Add in the "s" part of the key + ADDQ 0+sStore, acc0 + ADCQ 8+sStore, acc1 + + // Finally store the tag at the end of the message + MOVQ acc0, (0*8)(oup) + MOVQ acc1, (1*8)(oup) + RET + +// ---------------------------------------------------------------------------- +// ------------------------- AVX2 Code ---------------------------------------- +chacha20Poly1305Seal_AVX2: + VZEROUPPER + VMOVDQU ·chacha20Constants<>(SB), AA0 + BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x70; BYTE $0x10 // broadcasti128 16(r8), ymm14 + BYTE $0xc4; BYTE $0x42; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x20 // broadcasti128 32(r8), ymm12 + BYTE $0xc4; BYTE $0xc2; BYTE $0x7d; BYTE $0x5a; BYTE $0x60; BYTE $0x30 // broadcasti128 48(r8), ymm4 + VPADDD ·avx2InitMask<>(SB), DD0, DD0 + + // Special optimizations, for very short buffers + CMPQ inl, $192 + JBE seal192AVX2 // 33% faster + CMPQ inl, $320 + JBE seal320AVX2 // 17% faster + + // For the general key prepare the key first - as a byproduct we have 64 bytes of cipher stream + VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3 + VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3; VMOVDQA BB0, state1StoreAVX2 + VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3; VMOVDQA CC0, state2StoreAVX2 + VPADDD ·avx2IncMask<>(SB), DD0, DD1; VMOVDQA DD0, ctr0StoreAVX2 + VPADDD ·avx2IncMask<>(SB), DD1, DD2; VMOVDQA DD1, ctr1StoreAVX2 + VPADDD ·avx2IncMask<>(SB), DD2, DD3; VMOVDQA DD2, ctr2StoreAVX2 + VMOVDQA DD3, ctr3StoreAVX2 + MOVQ $10, itr2 + +sealAVX2IntroLoop: + VMOVDQA CC3, tmpStoreAVX2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3) + VMOVDQA tmpStoreAVX2, CC3 + VMOVDQA CC1, tmpStoreAVX2 + chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1) + VMOVDQA tmpStoreAVX2, CC1 + + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $12, DD0, DD0, DD0 + VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $12, DD1, DD1, DD1 + VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $12, DD2, DD2, DD2 + VPALIGNR $4, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $12, DD3, DD3, DD3 + + VMOVDQA CC3, tmpStoreAVX2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3) + VMOVDQA tmpStoreAVX2, CC3 + VMOVDQA CC1, tmpStoreAVX2 + chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1) + VMOVDQA tmpStoreAVX2, CC1 + + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $4, DD0, DD0, DD0 + VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $4, DD1, DD1, DD1 + VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $4, DD2, DD2, DD2 + VPALIGNR $12, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $4, DD3, DD3, DD3 + DECQ itr2 + JNE sealAVX2IntroLoop + + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3 + VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3 + + VPERM2I128 $0x13, CC0, DD0, CC0 // Stream bytes 96 - 127 + VPERM2I128 $0x02, AA0, BB0, DD0 // The Poly1305 key + VPERM2I128 $0x13, AA0, BB0, AA0 // Stream bytes 64 - 95 + + // Clamp and store poly key + VPAND ·polyClampMask<>(SB), DD0, DD0 + VMOVDQA DD0, rsStoreAVX2 + + // Hash AD + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + + // Can store at least 320 bytes + VPXOR (0*32)(inp), AA0, AA0 + VPXOR (1*32)(inp), CC0, CC0 + VMOVDQU AA0, (0*32)(oup) + VMOVDQU CC0, (1*32)(oup) + + VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0 + VPXOR (2*32)(inp), AA0, AA0; VPXOR (3*32)(inp), BB0, BB0; VPXOR (4*32)(inp), CC0, CC0; VPXOR (5*32)(inp), DD0, DD0 + VMOVDQU AA0, (2*32)(oup); VMOVDQU BB0, (3*32)(oup); VMOVDQU CC0, (4*32)(oup); VMOVDQU DD0, (5*32)(oup) + VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0 + VPXOR (6*32)(inp), AA0, AA0; VPXOR (7*32)(inp), BB0, BB0; VPXOR (8*32)(inp), CC0, CC0; VPXOR (9*32)(inp), DD0, DD0 + VMOVDQU AA0, (6*32)(oup); VMOVDQU BB0, (7*32)(oup); VMOVDQU CC0, (8*32)(oup); VMOVDQU DD0, (9*32)(oup) + + MOVQ $320, itr1 + SUBQ $320, inl + LEAQ 320(inp), inp + + VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, CC3, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, CC3, DD3, DD0 + CMPQ inl, $128 + JBE sealAVX2SealHash + + VPXOR (0*32)(inp), AA0, AA0; VPXOR (1*32)(inp), BB0, BB0; VPXOR (2*32)(inp), CC0, CC0; VPXOR (3*32)(inp), DD0, DD0 + VMOVDQU AA0, (10*32)(oup); VMOVDQU BB0, (11*32)(oup); VMOVDQU CC0, (12*32)(oup); VMOVDQU DD0, (13*32)(oup) + SUBQ $128, inl + LEAQ 128(inp), inp + + MOVQ $8, itr1 + MOVQ $2, itr2 + + CMPQ inl, $128 + JBE sealAVX2Tail128 + CMPQ inl, $256 + JBE sealAVX2Tail256 + CMPQ inl, $384 + JBE sealAVX2Tail384 + CMPQ inl, $512 + JBE sealAVX2Tail512 + + // We have 448 bytes to hash, but main loop hashes 512 bytes at a time - perform some rounds, before the main loop + VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3 + VMOVDQA ctr3StoreAVX2, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3 + VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2 + + VMOVDQA CC3, tmpStoreAVX2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3) + VMOVDQA tmpStoreAVX2, CC3 + VMOVDQA CC1, tmpStoreAVX2 + chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1) + VMOVDQA tmpStoreAVX2, CC1 + + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $12, DD0, DD0, DD0 + VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $12, DD1, DD1, DD1 + VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $12, DD2, DD2, DD2 + VPALIGNR $4, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $12, DD3, DD3, DD3 + + VMOVDQA CC3, tmpStoreAVX2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, CC3); chachaQR_AVX2(AA1, BB1, CC1, DD1, CC3); chachaQR_AVX2(AA2, BB2, CC2, DD2, CC3) + VMOVDQA tmpStoreAVX2, CC3 + VMOVDQA CC1, tmpStoreAVX2 + chachaQR_AVX2(AA3, BB3, CC3, DD3, CC1) + VMOVDQA tmpStoreAVX2, CC1 + + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $4, DD0, DD0, DD0 + VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $4, DD1, DD1, DD1 + VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $4, DD2, DD2, DD2 + VPALIGNR $12, BB3, BB3, BB3; VPALIGNR $8, CC3, CC3, CC3; VPALIGNR $4, DD3, DD3, DD3 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + + SUBQ $16, oup // Adjust the pointer + MOVQ $9, itr1 + JMP sealAVX2InternalLoopStart + +sealAVX2MainLoop: + // Load state, increment counter blocks, store the incremented counters + VMOVDQU ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3 + VMOVDQA ctr3StoreAVX2, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3 + VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2 + MOVQ $10, itr1 + +sealAVX2InternalLoop: + polyAdd(0*8(oup)) + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + polyMulStage1_AVX2 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + polyMulStage2_AVX2 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + polyMulStage3_AVX2 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulReduceStage + +sealAVX2InternalLoopStart: + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + polyAdd(2*8(oup)) + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + polyMulStage1_AVX2 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulStage2_AVX2 + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + polyMulStage3_AVX2 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + polyMulReduceStage + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + polyAdd(4*8(oup)) + LEAQ (6*8)(oup), oup + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulStage1_AVX2 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + polyMulStage2_AVX2 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + polyMulStage3_AVX2 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyMulReduceStage + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3 + DECQ itr1 + JNE sealAVX2InternalLoop + + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3 + VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3 + VMOVDQA CC3, tmpStoreAVX2 + + // We only hashed 480 of the 512 bytes available - hash the remaining 32 here + polyAdd(0*8(oup)) + polyMulAVX2 + LEAQ (4*8)(oup), oup + VPERM2I128 $0x02, AA0, BB0, CC3; VPERM2I128 $0x13, AA0, BB0, BB0; VPERM2I128 $0x02, CC0, DD0, AA0; VPERM2I128 $0x13, CC0, DD0, CC0 + VPXOR (0*32)(inp), CC3, CC3; VPXOR (1*32)(inp), AA0, AA0; VPXOR (2*32)(inp), BB0, BB0; VPXOR (3*32)(inp), CC0, CC0 + VMOVDQU CC3, (0*32)(oup); VMOVDQU AA0, (1*32)(oup); VMOVDQU BB0, (2*32)(oup); VMOVDQU CC0, (3*32)(oup) + VPERM2I128 $0x02, AA1, BB1, AA0; VPERM2I128 $0x02, CC1, DD1, BB0; VPERM2I128 $0x13, AA1, BB1, CC0; VPERM2I128 $0x13, CC1, DD1, DD0 + VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0 + VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup) + + // and here + polyAdd(-2*8(oup)) + polyMulAVX2 + VPERM2I128 $0x02, AA2, BB2, AA0; VPERM2I128 $0x02, CC2, DD2, BB0; VPERM2I128 $0x13, AA2, BB2, CC0; VPERM2I128 $0x13, CC2, DD2, DD0 + VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0 + VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup) + VPERM2I128 $0x02, AA3, BB3, AA0; VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0; VPERM2I128 $0x13, AA3, BB3, CC0; VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0 + VPXOR (12*32)(inp), AA0, AA0; VPXOR (13*32)(inp), BB0, BB0; VPXOR (14*32)(inp), CC0, CC0; VPXOR (15*32)(inp), DD0, DD0 + VMOVDQU AA0, (12*32)(oup); VMOVDQU BB0, (13*32)(oup); VMOVDQU CC0, (14*32)(oup); VMOVDQU DD0, (15*32)(oup) + LEAQ (32*16)(inp), inp + SUBQ $(32*16), inl + CMPQ inl, $512 + JG sealAVX2MainLoop + + // Tail can only hash 480 bytes + polyAdd(0*8(oup)) + polyMulAVX2 + polyAdd(2*8(oup)) + polyMulAVX2 + LEAQ 32(oup), oup + + MOVQ $10, itr1 + MOVQ $0, itr2 + CMPQ inl, $128 + JBE sealAVX2Tail128 + CMPQ inl, $256 + JBE sealAVX2Tail256 + CMPQ inl, $384 + JBE sealAVX2Tail384 + JMP sealAVX2Tail512 + +// ---------------------------------------------------------------------------- +// Special optimization for buffers smaller than 193 bytes +seal192AVX2: + // For up to 192 bytes of ciphertext and 64 bytes for the poly key, we process four blocks + VMOVDQA AA0, AA1 + VMOVDQA BB0, BB1 + VMOVDQA CC0, CC1 + VPADDD ·avx2IncMask<>(SB), DD0, DD1 + VMOVDQA AA0, AA2 + VMOVDQA BB0, BB2 + VMOVDQA CC0, CC2 + VMOVDQA DD0, DD2 + VMOVDQA DD1, TT3 + MOVQ $10, itr2 + +sealAVX2192InnerCipherLoop: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1 + DECQ itr2 + JNE sealAVX2192InnerCipherLoop + VPADDD AA2, AA0, AA0; VPADDD AA2, AA1, AA1 + VPADDD BB2, BB0, BB0; VPADDD BB2, BB1, BB1 + VPADDD CC2, CC0, CC0; VPADDD CC2, CC1, CC1 + VPADDD DD2, DD0, DD0; VPADDD TT3, DD1, DD1 + VPERM2I128 $0x02, AA0, BB0, TT0 + + // Clamp and store poly key + VPAND ·polyClampMask<>(SB), TT0, TT0 + VMOVDQA TT0, rsStoreAVX2 + + // Stream for up to 192 bytes + VPERM2I128 $0x13, AA0, BB0, AA0 + VPERM2I128 $0x13, CC0, DD0, BB0 + VPERM2I128 $0x02, AA1, BB1, CC0 + VPERM2I128 $0x02, CC1, DD1, DD0 + VPERM2I128 $0x13, AA1, BB1, AA1 + VPERM2I128 $0x13, CC1, DD1, BB1 + +sealAVX2ShortSeal: + // Hash aad + MOVQ ad_len+80(FP), itr2 + CALL polyHashADInternal<>(SB) + XORQ itr1, itr1 + +sealAVX2SealHash: + // itr1 holds the number of bytes encrypted but not yet hashed + CMPQ itr1, $16 + JB sealAVX2ShortSealLoop + polyAdd(0(oup)) + polyMul + SUBQ $16, itr1 + ADDQ $16, oup + JMP sealAVX2SealHash + +sealAVX2ShortSealLoop: + CMPQ inl, $32 + JB sealAVX2ShortTail32 + SUBQ $32, inl + + // Load for encryption + VPXOR (inp), AA0, AA0 + VMOVDQU AA0, (oup) + LEAQ (1*32)(inp), inp + + // Now can hash + polyAdd(0*8(oup)) + polyMulAVX2 + polyAdd(2*8(oup)) + polyMulAVX2 + LEAQ (1*32)(oup), oup + + // Shift stream left + VMOVDQA BB0, AA0 + VMOVDQA CC0, BB0 + VMOVDQA DD0, CC0 + VMOVDQA AA1, DD0 + VMOVDQA BB1, AA1 + VMOVDQA CC1, BB1 + VMOVDQA DD1, CC1 + VMOVDQA AA2, DD1 + VMOVDQA BB2, AA2 + JMP sealAVX2ShortSealLoop + +sealAVX2ShortTail32: + CMPQ inl, $16 + VMOVDQA A0, A1 + JB sealAVX2ShortDone + + SUBQ $16, inl + + // Load for encryption + VPXOR (inp), A0, T0 + VMOVDQU T0, (oup) + LEAQ (1*16)(inp), inp + + // Hash + polyAdd(0*8(oup)) + polyMulAVX2 + LEAQ (1*16)(oup), oup + VPERM2I128 $0x11, AA0, AA0, AA0 + VMOVDQA A0, A1 + +sealAVX2ShortDone: + VZEROUPPER + JMP sealSSETail + +// ---------------------------------------------------------------------------- +// Special optimization for buffers smaller than 321 bytes +seal320AVX2: + // For up to 320 bytes of ciphertext and 64 bytes for the poly key, we process six blocks + VMOVDQA AA0, AA1; VMOVDQA BB0, BB1; VMOVDQA CC0, CC1; VPADDD ·avx2IncMask<>(SB), DD0, DD1 + VMOVDQA AA0, AA2; VMOVDQA BB0, BB2; VMOVDQA CC0, CC2; VPADDD ·avx2IncMask<>(SB), DD1, DD2 + VMOVDQA BB0, TT1; VMOVDQA CC0, TT2; VMOVDQA DD0, TT3 + MOVQ $10, itr2 + +sealAVX2320InnerCipherLoop: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2 + DECQ itr2 + JNE sealAVX2320InnerCipherLoop + + VMOVDQA ·chacha20Constants<>(SB), TT0 + VPADDD TT0, AA0, AA0; VPADDD TT0, AA1, AA1; VPADDD TT0, AA2, AA2 + VPADDD TT1, BB0, BB0; VPADDD TT1, BB1, BB1; VPADDD TT1, BB2, BB2 + VPADDD TT2, CC0, CC0; VPADDD TT2, CC1, CC1; VPADDD TT2, CC2, CC2 + VMOVDQA ·avx2IncMask<>(SB), TT0 + VPADDD TT3, DD0, DD0; VPADDD TT0, TT3, TT3 + VPADDD TT3, DD1, DD1; VPADDD TT0, TT3, TT3 + VPADDD TT3, DD2, DD2 + + // Clamp and store poly key + VPERM2I128 $0x02, AA0, BB0, TT0 + VPAND ·polyClampMask<>(SB), TT0, TT0 + VMOVDQA TT0, rsStoreAVX2 + + // Stream for up to 320 bytes + VPERM2I128 $0x13, AA0, BB0, AA0 + VPERM2I128 $0x13, CC0, DD0, BB0 + VPERM2I128 $0x02, AA1, BB1, CC0 + VPERM2I128 $0x02, CC1, DD1, DD0 + VPERM2I128 $0x13, AA1, BB1, AA1 + VPERM2I128 $0x13, CC1, DD1, BB1 + VPERM2I128 $0x02, AA2, BB2, CC1 + VPERM2I128 $0x02, CC2, DD2, DD1 + VPERM2I128 $0x13, AA2, BB2, AA2 + VPERM2I128 $0x13, CC2, DD2, BB2 + JMP sealAVX2ShortSeal + +// ---------------------------------------------------------------------------- +// Special optimization for the last 128 bytes of ciphertext +sealAVX2Tail128: + // Need to decrypt up to 128 bytes - prepare two blocks + // If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed + // If we got here before the main loop - there are 448 encrpyred bytes waiting to be hashed + VMOVDQA ·chacha20Constants<>(SB), AA0 + VMOVDQA state1StoreAVX2, BB0 + VMOVDQA state2StoreAVX2, CC0 + VMOVDQA ctr3StoreAVX2, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD0 + VMOVDQA DD0, DD1 + +sealAVX2Tail128LoopA: + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + +sealAVX2Tail128LoopB: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0) + polyAdd(0(oup)) + polyMul + VPALIGNR $4, BB0, BB0, BB0 + VPALIGNR $8, CC0, CC0, CC0 + VPALIGNR $12, DD0, DD0, DD0 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0) + polyAdd(16(oup)) + polyMul + LEAQ 32(oup), oup + VPALIGNR $12, BB0, BB0, BB0 + VPALIGNR $8, CC0, CC0, CC0 + VPALIGNR $4, DD0, DD0, DD0 + DECQ itr1 + JG sealAVX2Tail128LoopA + DECQ itr2 + JGE sealAVX2Tail128LoopB + + VPADDD ·chacha20Constants<>(SB), AA0, AA1 + VPADDD state1StoreAVX2, BB0, BB1 + VPADDD state2StoreAVX2, CC0, CC1 + VPADDD DD1, DD0, DD1 + + VPERM2I128 $0x02, AA1, BB1, AA0 + VPERM2I128 $0x02, CC1, DD1, BB0 + VPERM2I128 $0x13, AA1, BB1, CC0 + VPERM2I128 $0x13, CC1, DD1, DD0 + JMP sealAVX2ShortSealLoop + +// ---------------------------------------------------------------------------- +// Special optimization for the last 256 bytes of ciphertext +sealAVX2Tail256: + // Need to decrypt up to 256 bytes - prepare two blocks + // If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed + // If we got here before the main loop - there are 448 encrpyred bytes waiting to be hashed + VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA ·chacha20Constants<>(SB), AA1 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA state1StoreAVX2, BB1 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA state2StoreAVX2, CC1 + VMOVDQA ctr3StoreAVX2, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD1 + VMOVDQA DD0, TT1 + VMOVDQA DD1, TT2 + +sealAVX2Tail256LoopA: + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + +sealAVX2Tail256LoopB: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + polyAdd(0(oup)) + polyMul + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0) + polyAdd(16(oup)) + polyMul + LEAQ 32(oup), oup + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1 + DECQ itr1 + JG sealAVX2Tail256LoopA + DECQ itr2 + JGE sealAVX2Tail256LoopB + + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1 + VPADDD TT1, DD0, DD0; VPADDD TT2, DD1, DD1 + VPERM2I128 $0x02, AA0, BB0, TT0 + VPERM2I128 $0x02, CC0, DD0, TT1 + VPERM2I128 $0x13, AA0, BB0, TT2 + VPERM2I128 $0x13, CC0, DD0, TT3 + VPXOR (0*32)(inp), TT0, TT0; VPXOR (1*32)(inp), TT1, TT1; VPXOR (2*32)(inp), TT2, TT2; VPXOR (3*32)(inp), TT3, TT3 + VMOVDQU TT0, (0*32)(oup); VMOVDQU TT1, (1*32)(oup); VMOVDQU TT2, (2*32)(oup); VMOVDQU TT3, (3*32)(oup) + MOVQ $128, itr1 + LEAQ 128(inp), inp + SUBQ $128, inl + VPERM2I128 $0x02, AA1, BB1, AA0 + VPERM2I128 $0x02, CC1, DD1, BB0 + VPERM2I128 $0x13, AA1, BB1, CC0 + VPERM2I128 $0x13, CC1, DD1, DD0 + + JMP sealAVX2SealHash + +// ---------------------------------------------------------------------------- +// Special optimization for the last 384 bytes of ciphertext +sealAVX2Tail384: + // Need to decrypt up to 384 bytes - prepare two blocks + // If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed + // If we got here before the main loop - there are 448 encrpyred bytes waiting to be hashed + VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2 + VMOVDQA ctr3StoreAVX2, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2 + VMOVDQA DD0, TT1; VMOVDQA DD1, TT2; VMOVDQA DD2, TT3 + +sealAVX2Tail384LoopA: + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + +sealAVX2Tail384LoopB: + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + polyAdd(0(oup)) + polyMul + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2 + chachaQR_AVX2(AA0, BB0, CC0, DD0, TT0); chachaQR_AVX2(AA1, BB1, CC1, DD1, TT0); chachaQR_AVX2(AA2, BB2, CC2, DD2, TT0) + polyAdd(16(oup)) + polyMul + LEAQ 32(oup), oup + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2 + DECQ itr1 + JG sealAVX2Tail384LoopA + DECQ itr2 + JGE sealAVX2Tail384LoopB + + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2 + VPADDD TT1, DD0, DD0; VPADDD TT2, DD1, DD1; VPADDD TT3, DD2, DD2 + VPERM2I128 $0x02, AA0, BB0, TT0 + VPERM2I128 $0x02, CC0, DD0, TT1 + VPERM2I128 $0x13, AA0, BB0, TT2 + VPERM2I128 $0x13, CC0, DD0, TT3 + VPXOR (0*32)(inp), TT0, TT0; VPXOR (1*32)(inp), TT1, TT1; VPXOR (2*32)(inp), TT2, TT2; VPXOR (3*32)(inp), TT3, TT3 + VMOVDQU TT0, (0*32)(oup); VMOVDQU TT1, (1*32)(oup); VMOVDQU TT2, (2*32)(oup); VMOVDQU TT3, (3*32)(oup) + VPERM2I128 $0x02, AA1, BB1, TT0 + VPERM2I128 $0x02, CC1, DD1, TT1 + VPERM2I128 $0x13, AA1, BB1, TT2 + VPERM2I128 $0x13, CC1, DD1, TT3 + VPXOR (4*32)(inp), TT0, TT0; VPXOR (5*32)(inp), TT1, TT1; VPXOR (6*32)(inp), TT2, TT2; VPXOR (7*32)(inp), TT3, TT3 + VMOVDQU TT0, (4*32)(oup); VMOVDQU TT1, (5*32)(oup); VMOVDQU TT2, (6*32)(oup); VMOVDQU TT3, (7*32)(oup) + MOVQ $256, itr1 + LEAQ 256(inp), inp + SUBQ $256, inl + VPERM2I128 $0x02, AA2, BB2, AA0 + VPERM2I128 $0x02, CC2, DD2, BB0 + VPERM2I128 $0x13, AA2, BB2, CC0 + VPERM2I128 $0x13, CC2, DD2, DD0 + + JMP sealAVX2SealHash + +// ---------------------------------------------------------------------------- +// Special optimization for the last 512 bytes of ciphertext +sealAVX2Tail512: + // Need to decrypt up to 512 bytes - prepare two blocks + // If we got here after the main loop - there are 512 encrypted bytes waiting to be hashed + // If we got here before the main loop - there are 448 encrpyred bytes waiting to be hashed + VMOVDQA ·chacha20Constants<>(SB), AA0; VMOVDQA AA0, AA1; VMOVDQA AA0, AA2; VMOVDQA AA0, AA3 + VMOVDQA state1StoreAVX2, BB0; VMOVDQA BB0, BB1; VMOVDQA BB0, BB2; VMOVDQA BB0, BB3 + VMOVDQA state2StoreAVX2, CC0; VMOVDQA CC0, CC1; VMOVDQA CC0, CC2; VMOVDQA CC0, CC3 + VMOVDQA ctr3StoreAVX2, DD0 + VPADDD ·avx2IncMask<>(SB), DD0, DD0; VPADDD ·avx2IncMask<>(SB), DD0, DD1; VPADDD ·avx2IncMask<>(SB), DD1, DD2; VPADDD ·avx2IncMask<>(SB), DD2, DD3 + VMOVDQA DD0, ctr0StoreAVX2; VMOVDQA DD1, ctr1StoreAVX2; VMOVDQA DD2, ctr2StoreAVX2; VMOVDQA DD3, ctr3StoreAVX2 + +sealAVX2Tail512LoopA: + polyAdd(0(oup)) + polyMul + LEAQ 16(oup), oup + +sealAVX2Tail512LoopB: + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + polyAdd(0*8(oup)) + polyMulAVX2 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + VPALIGNR $4, BB0, BB0, BB0; VPALIGNR $4, BB1, BB1, BB1; VPALIGNR $4, BB2, BB2, BB2; VPALIGNR $4, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $12, DD0, DD0, DD0; VPALIGNR $12, DD1, DD1, DD1; VPALIGNR $12, DD2, DD2, DD2; VPALIGNR $12, DD3, DD3, DD3 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol16<>(SB), DD0, DD0; VPSHUFB ·rol16<>(SB), DD1, DD1; VPSHUFB ·rol16<>(SB), DD2, DD2; VPSHUFB ·rol16<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + polyAdd(2*8(oup)) + polyMulAVX2 + LEAQ (4*8)(oup), oup + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $12, BB0, CC3; VPSRLD $20, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $12, BB1, CC3; VPSRLD $20, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $12, BB2, CC3; VPSRLD $20, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $12, BB3, CC3; VPSRLD $20, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + VPADDD BB0, AA0, AA0; VPADDD BB1, AA1, AA1; VPADDD BB2, AA2, AA2; VPADDD BB3, AA3, AA3 + VPXOR AA0, DD0, DD0; VPXOR AA1, DD1, DD1; VPXOR AA2, DD2, DD2; VPXOR AA3, DD3, DD3 + VPSHUFB ·rol8<>(SB), DD0, DD0; VPSHUFB ·rol8<>(SB), DD1, DD1; VPSHUFB ·rol8<>(SB), DD2, DD2; VPSHUFB ·rol8<>(SB), DD3, DD3 + VPADDD DD0, CC0, CC0; VPADDD DD1, CC1, CC1; VPADDD DD2, CC2, CC2; VPADDD DD3, CC3, CC3 + VPXOR CC0, BB0, BB0; VPXOR CC1, BB1, BB1; VPXOR CC2, BB2, BB2; VPXOR CC3, BB3, BB3 + VMOVDQA CC3, tmpStoreAVX2 + VPSLLD $7, BB0, CC3; VPSRLD $25, BB0, BB0; VPXOR CC3, BB0, BB0 + VPSLLD $7, BB1, CC3; VPSRLD $25, BB1, BB1; VPXOR CC3, BB1, BB1 + VPSLLD $7, BB2, CC3; VPSRLD $25, BB2, BB2; VPXOR CC3, BB2, BB2 + VPSLLD $7, BB3, CC3; VPSRLD $25, BB3, BB3; VPXOR CC3, BB3, BB3 + VMOVDQA tmpStoreAVX2, CC3 + VPALIGNR $12, BB0, BB0, BB0; VPALIGNR $12, BB1, BB1, BB1; VPALIGNR $12, BB2, BB2, BB2; VPALIGNR $12, BB3, BB3, BB3 + VPALIGNR $8, CC0, CC0, CC0; VPALIGNR $8, CC1, CC1, CC1; VPALIGNR $8, CC2, CC2, CC2; VPALIGNR $8, CC3, CC3, CC3 + VPALIGNR $4, DD0, DD0, DD0; VPALIGNR $4, DD1, DD1, DD1; VPALIGNR $4, DD2, DD2, DD2; VPALIGNR $4, DD3, DD3, DD3 + + DECQ itr1 + JG sealAVX2Tail512LoopA + DECQ itr2 + JGE sealAVX2Tail512LoopB + + VPADDD ·chacha20Constants<>(SB), AA0, AA0; VPADDD ·chacha20Constants<>(SB), AA1, AA1; VPADDD ·chacha20Constants<>(SB), AA2, AA2; VPADDD ·chacha20Constants<>(SB), AA3, AA3 + VPADDD state1StoreAVX2, BB0, BB0; VPADDD state1StoreAVX2, BB1, BB1; VPADDD state1StoreAVX2, BB2, BB2; VPADDD state1StoreAVX2, BB3, BB3 + VPADDD state2StoreAVX2, CC0, CC0; VPADDD state2StoreAVX2, CC1, CC1; VPADDD state2StoreAVX2, CC2, CC2; VPADDD state2StoreAVX2, CC3, CC3 + VPADDD ctr0StoreAVX2, DD0, DD0; VPADDD ctr1StoreAVX2, DD1, DD1; VPADDD ctr2StoreAVX2, DD2, DD2; VPADDD ctr3StoreAVX2, DD3, DD3 + VMOVDQA CC3, tmpStoreAVX2 + VPERM2I128 $0x02, AA0, BB0, CC3 + VPXOR (0*32)(inp), CC3, CC3 + VMOVDQU CC3, (0*32)(oup) + VPERM2I128 $0x02, CC0, DD0, CC3 + VPXOR (1*32)(inp), CC3, CC3 + VMOVDQU CC3, (1*32)(oup) + VPERM2I128 $0x13, AA0, BB0, CC3 + VPXOR (2*32)(inp), CC3, CC3 + VMOVDQU CC3, (2*32)(oup) + VPERM2I128 $0x13, CC0, DD0, CC3 + VPXOR (3*32)(inp), CC3, CC3 + VMOVDQU CC3, (3*32)(oup) + + VPERM2I128 $0x02, AA1, BB1, AA0 + VPERM2I128 $0x02, CC1, DD1, BB0 + VPERM2I128 $0x13, AA1, BB1, CC0 + VPERM2I128 $0x13, CC1, DD1, DD0 + VPXOR (4*32)(inp), AA0, AA0; VPXOR (5*32)(inp), BB0, BB0; VPXOR (6*32)(inp), CC0, CC0; VPXOR (7*32)(inp), DD0, DD0 + VMOVDQU AA0, (4*32)(oup); VMOVDQU BB0, (5*32)(oup); VMOVDQU CC0, (6*32)(oup); VMOVDQU DD0, (7*32)(oup) + + VPERM2I128 $0x02, AA2, BB2, AA0 + VPERM2I128 $0x02, CC2, DD2, BB0 + VPERM2I128 $0x13, AA2, BB2, CC0 + VPERM2I128 $0x13, CC2, DD2, DD0 + VPXOR (8*32)(inp), AA0, AA0; VPXOR (9*32)(inp), BB0, BB0; VPXOR (10*32)(inp), CC0, CC0; VPXOR (11*32)(inp), DD0, DD0 + VMOVDQU AA0, (8*32)(oup); VMOVDQU BB0, (9*32)(oup); VMOVDQU CC0, (10*32)(oup); VMOVDQU DD0, (11*32)(oup) + + MOVQ $384, itr1 + LEAQ 384(inp), inp + SUBQ $384, inl + VPERM2I128 $0x02, AA3, BB3, AA0 + VPERM2I128 $0x02, tmpStoreAVX2, DD3, BB0 + VPERM2I128 $0x13, AA3, BB3, CC0 + VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0 + + JMP sealAVX2SealHash diff --git a/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_generic.go b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_generic.go new file mode 100644 index 00000000..c2797121 --- /dev/null +++ b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_generic.go @@ -0,0 +1,81 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package chacha20poly1305 + +import ( + "encoding/binary" + + "golang.org/x/crypto/internal/chacha20" + "golang.org/x/crypto/internal/subtle" + "golang.org/x/crypto/poly1305" +) + +func roundTo16(n int) int { + return 16 * ((n + 15) / 16) +} + +func (c *chacha20poly1305) sealGeneric(dst, nonce, plaintext, additionalData []byte) []byte { + ret, out := sliceForAppend(dst, len(plaintext)+poly1305.TagSize) + if subtle.InexactOverlap(out, plaintext) { + panic("chacha20poly1305: invalid buffer overlap") + } + + var polyKey [32]byte + s := chacha20.New(c.key, [3]uint32{ + binary.LittleEndian.Uint32(nonce[0:4]), + binary.LittleEndian.Uint32(nonce[4:8]), + binary.LittleEndian.Uint32(nonce[8:12]), + }) + s.XORKeyStream(polyKey[:], polyKey[:]) + s.Advance() // skip the next 32 bytes + s.XORKeyStream(out, plaintext) + + polyInput := make([]byte, roundTo16(len(additionalData))+roundTo16(len(plaintext))+8+8) + copy(polyInput, additionalData) + copy(polyInput[roundTo16(len(additionalData)):], out[:len(plaintext)]) + binary.LittleEndian.PutUint64(polyInput[len(polyInput)-16:], uint64(len(additionalData))) + binary.LittleEndian.PutUint64(polyInput[len(polyInput)-8:], uint64(len(plaintext))) + + var tag [poly1305.TagSize]byte + poly1305.Sum(&tag, polyInput, &polyKey) + copy(out[len(plaintext):], tag[:]) + + return ret +} + +func (c *chacha20poly1305) openGeneric(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + var tag [poly1305.TagSize]byte + copy(tag[:], ciphertext[len(ciphertext)-16:]) + ciphertext = ciphertext[:len(ciphertext)-16] + + var polyKey [32]byte + s := chacha20.New(c.key, [3]uint32{ + binary.LittleEndian.Uint32(nonce[0:4]), + binary.LittleEndian.Uint32(nonce[4:8]), + binary.LittleEndian.Uint32(nonce[8:12]), + }) + s.XORKeyStream(polyKey[:], polyKey[:]) + s.Advance() // skip the next 32 bytes + + polyInput := make([]byte, roundTo16(len(additionalData))+roundTo16(len(ciphertext))+8+8) + copy(polyInput, additionalData) + copy(polyInput[roundTo16(len(additionalData)):], ciphertext) + binary.LittleEndian.PutUint64(polyInput[len(polyInput)-16:], uint64(len(additionalData))) + binary.LittleEndian.PutUint64(polyInput[len(polyInput)-8:], uint64(len(ciphertext))) + + ret, out := sliceForAppend(dst, len(ciphertext)) + if subtle.InexactOverlap(out, ciphertext) { + panic("chacha20poly1305: invalid buffer overlap") + } + if !poly1305.Verify(&tag, polyInput, &polyKey) { + for i := range out { + out[i] = 0 + } + return nil, errOpen + } + + s.XORKeyStream(out, ciphertext) + return ret, nil +} diff --git a/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_noasm.go b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_noasm.go new file mode 100644 index 00000000..4c2eb703 --- /dev/null +++ b/vendor/golang.org/x/crypto/chacha20poly1305/chacha20poly1305_noasm.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !amd64 !go1.7 gccgo appengine + +package chacha20poly1305 + +func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte { + return c.sealGeneric(dst, nonce, plaintext, additionalData) +} + +func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + return c.openGeneric(dst, nonce, ciphertext, additionalData) +} diff --git a/vendor/golang.org/x/crypto/chacha20poly1305/xchacha20poly1305.go b/vendor/golang.org/x/crypto/chacha20poly1305/xchacha20poly1305.go new file mode 100644 index 00000000..a02fa571 --- /dev/null +++ b/vendor/golang.org/x/crypto/chacha20poly1305/xchacha20poly1305.go @@ -0,0 +1,104 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package chacha20poly1305 + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "golang.org/x/crypto/internal/chacha20" +) + +type xchacha20poly1305 struct { + key [8]uint32 +} + +// NewX returns a XChaCha20-Poly1305 AEAD that uses the given 256-bit key. +// +// XChaCha20-Poly1305 is a ChaCha20-Poly1305 variant that takes a longer nonce, +// suitable to be generated randomly without risk of collisions. It should be +// preferred when nonce uniqueness cannot be trivially ensured, or whenever +// nonces are randomly generated. +func NewX(key []byte) (cipher.AEAD, error) { + if len(key) != KeySize { + return nil, errors.New("chacha20poly1305: bad key length") + } + ret := new(xchacha20poly1305) + ret.key[0] = binary.LittleEndian.Uint32(key[0:4]) + ret.key[1] = binary.LittleEndian.Uint32(key[4:8]) + ret.key[2] = binary.LittleEndian.Uint32(key[8:12]) + ret.key[3] = binary.LittleEndian.Uint32(key[12:16]) + ret.key[4] = binary.LittleEndian.Uint32(key[16:20]) + ret.key[5] = binary.LittleEndian.Uint32(key[20:24]) + ret.key[6] = binary.LittleEndian.Uint32(key[24:28]) + ret.key[7] = binary.LittleEndian.Uint32(key[28:32]) + return ret, nil +} + +func (*xchacha20poly1305) NonceSize() int { + return NonceSizeX +} + +func (*xchacha20poly1305) Overhead() int { + return 16 +} + +func (x *xchacha20poly1305) Seal(dst, nonce, plaintext, additionalData []byte) []byte { + if len(nonce) != NonceSizeX { + panic("chacha20poly1305: bad nonce length passed to Seal") + } + + // XChaCha20-Poly1305 technically supports a 64-bit counter, so there is no + // size limit. However, since we reuse the ChaCha20-Poly1305 implementation, + // the second half of the counter is not available. This is unlikely to be + // an issue because the cipher.AEAD API requires the entire message to be in + // memory, and the counter overflows at 256 GB. + if uint64(len(plaintext)) > (1<<38)-64 { + panic("chacha20poly1305: plaintext too large") + } + + hNonce := [4]uint32{ + binary.LittleEndian.Uint32(nonce[0:4]), + binary.LittleEndian.Uint32(nonce[4:8]), + binary.LittleEndian.Uint32(nonce[8:12]), + binary.LittleEndian.Uint32(nonce[12:16]), + } + c := &chacha20poly1305{ + key: chacha20.HChaCha20(&x.key, &hNonce), + } + // The first 4 bytes of the final nonce are unused counter space. + cNonce := make([]byte, NonceSize) + copy(cNonce[4:12], nonce[16:24]) + + return c.seal(dst, cNonce[:], plaintext, additionalData) +} + +func (x *xchacha20poly1305) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + if len(nonce) != NonceSizeX { + panic("chacha20poly1305: bad nonce length passed to Open") + } + if len(ciphertext) < 16 { + return nil, errOpen + } + if uint64(len(ciphertext)) > (1<<38)-48 { + panic("chacha20poly1305: ciphertext too large") + } + + hNonce := [4]uint32{ + binary.LittleEndian.Uint32(nonce[0:4]), + binary.LittleEndian.Uint32(nonce[4:8]), + binary.LittleEndian.Uint32(nonce[8:12]), + binary.LittleEndian.Uint32(nonce[12:16]), + } + c := &chacha20poly1305{ + key: chacha20.HChaCha20(&x.key, &hNonce), + } + // The first 4 bytes of the final nonce are unused counter space. + cNonce := make([]byte, NonceSize) + copy(cNonce[4:12], nonce[16:24]) + + return c.open(dst, cNonce[:], ciphertext, additionalData) +} diff --git a/vendor/golang.org/x/crypto/internal/chacha20/chacha_generic.go b/vendor/golang.org/x/crypto/internal/chacha20/chacha_generic.go new file mode 100644 index 00000000..6570847f --- /dev/null +++ b/vendor/golang.org/x/crypto/internal/chacha20/chacha_generic.go @@ -0,0 +1,264 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ChaCha20 implements the core ChaCha20 function as specified +// in https://tools.ietf.org/html/rfc7539#section-2.3. +package chacha20 + +import ( + "crypto/cipher" + "encoding/binary" + + "golang.org/x/crypto/internal/subtle" +) + +// assert that *Cipher implements cipher.Stream +var _ cipher.Stream = (*Cipher)(nil) + +// Cipher is a stateful instance of ChaCha20 using a particular key +// and nonce. A *Cipher implements the cipher.Stream interface. +type Cipher struct { + key [8]uint32 + counter uint32 // incremented after each block + nonce [3]uint32 + buf [bufSize]byte // buffer for unused keystream bytes + len int // number of unused keystream bytes at end of buf +} + +// New creates a new ChaCha20 stream cipher with the given key and nonce. +// The initial counter value is set to 0. +func New(key [8]uint32, nonce [3]uint32) *Cipher { + return &Cipher{key: key, nonce: nonce} +} + +// ChaCha20 constants spelling "expand 32-byte k" +const ( + j0 uint32 = 0x61707865 + j1 uint32 = 0x3320646e + j2 uint32 = 0x79622d32 + j3 uint32 = 0x6b206574 +) + +func quarterRound(a, b, c, d uint32) (uint32, uint32, uint32, uint32) { + a += b + d ^= a + d = (d << 16) | (d >> 16) + c += d + b ^= c + b = (b << 12) | (b >> 20) + a += b + d ^= a + d = (d << 8) | (d >> 24) + c += d + b ^= c + b = (b << 7) | (b >> 25) + return a, b, c, d +} + +// XORKeyStream XORs each byte in the given slice with a byte from the +// cipher's key stream. Dst and src must overlap entirely or not at all. +// +// If len(dst) < len(src), XORKeyStream will panic. It is acceptable +// to pass a dst bigger than src, and in that case, XORKeyStream will +// only update dst[:len(src)] and will not touch the rest of dst. +// +// Multiple calls to XORKeyStream behave as if the concatenation of +// the src buffers was passed in a single run. That is, Cipher +// maintains state and does not reset at each XORKeyStream call. +func (s *Cipher) XORKeyStream(dst, src []byte) { + if len(dst) < len(src) { + panic("chacha20: output smaller than input") + } + if subtle.InexactOverlap(dst[:len(src)], src) { + panic("chacha20: invalid buffer overlap") + } + + // xor src with buffered keystream first + if s.len != 0 { + buf := s.buf[len(s.buf)-s.len:] + if len(src) < len(buf) { + buf = buf[:len(src)] + } + td, ts := dst[:len(buf)], src[:len(buf)] // BCE hint + for i, b := range buf { + td[i] = ts[i] ^ b + } + s.len -= len(buf) + if s.len != 0 { + return + } + s.buf = [len(s.buf)]byte{} // zero the empty buffer + src = src[len(buf):] + dst = dst[len(buf):] + } + + if len(src) == 0 { + return + } + if haveAsm { + if uint64(len(src))+uint64(s.counter)*64 > (1<<38)-64 { + panic("chacha20: counter overflow") + } + s.xorKeyStreamAsm(dst, src) + return + } + + // set up a 64-byte buffer to pad out the final block if needed + // (hoisted out of the main loop to avoid spills) + rem := len(src) % 64 // length of final block + fin := len(src) - rem // index of final block + if rem > 0 { + copy(s.buf[len(s.buf)-64:], src[fin:]) + } + + // pre-calculate most of the first round + s1, s5, s9, s13 := quarterRound(j1, s.key[1], s.key[5], s.nonce[0]) + s2, s6, s10, s14 := quarterRound(j2, s.key[2], s.key[6], s.nonce[1]) + s3, s7, s11, s15 := quarterRound(j3, s.key[3], s.key[7], s.nonce[2]) + + n := len(src) + src, dst = src[:n:n], dst[:n:n] // BCE hint + for i := 0; i < n; i += 64 { + // calculate the remainder of the first round + s0, s4, s8, s12 := quarterRound(j0, s.key[0], s.key[4], s.counter) + + // execute the second round + x0, x5, x10, x15 := quarterRound(s0, s5, s10, s15) + x1, x6, x11, x12 := quarterRound(s1, s6, s11, s12) + x2, x7, x8, x13 := quarterRound(s2, s7, s8, s13) + x3, x4, x9, x14 := quarterRound(s3, s4, s9, s14) + + // execute the remaining 18 rounds + for i := 0; i < 9; i++ { + x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12) + x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13) + x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14) + x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15) + + x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15) + x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12) + x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13) + x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14) + } + + x0 += j0 + x1 += j1 + x2 += j2 + x3 += j3 + + x4 += s.key[0] + x5 += s.key[1] + x6 += s.key[2] + x7 += s.key[3] + x8 += s.key[4] + x9 += s.key[5] + x10 += s.key[6] + x11 += s.key[7] + + x12 += s.counter + x13 += s.nonce[0] + x14 += s.nonce[1] + x15 += s.nonce[2] + + // increment the counter + s.counter += 1 + if s.counter == 0 { + panic("chacha20: counter overflow") + } + + // pad to 64 bytes if needed + in, out := src[i:], dst[i:] + if i == fin { + // src[fin:] has already been copied into s.buf before + // the main loop + in, out = s.buf[len(s.buf)-64:], s.buf[len(s.buf)-64:] + } + in, out = in[:64], out[:64] // BCE hint + + // XOR the key stream with the source and write out the result + xor(out[0:], in[0:], x0) + xor(out[4:], in[4:], x1) + xor(out[8:], in[8:], x2) + xor(out[12:], in[12:], x3) + xor(out[16:], in[16:], x4) + xor(out[20:], in[20:], x5) + xor(out[24:], in[24:], x6) + xor(out[28:], in[28:], x7) + xor(out[32:], in[32:], x8) + xor(out[36:], in[36:], x9) + xor(out[40:], in[40:], x10) + xor(out[44:], in[44:], x11) + xor(out[48:], in[48:], x12) + xor(out[52:], in[52:], x13) + xor(out[56:], in[56:], x14) + xor(out[60:], in[60:], x15) + } + // copy any trailing bytes out of the buffer and into dst + if rem != 0 { + s.len = 64 - rem + copy(dst[fin:], s.buf[len(s.buf)-64:]) + } +} + +// Advance discards bytes in the key stream until the next 64 byte block +// boundary is reached and updates the counter accordingly. If the key +// stream is already at a block boundary no bytes will be discarded and +// the counter will be unchanged. +func (s *Cipher) Advance() { + s.len -= s.len % 64 + if s.len == 0 { + s.buf = [len(s.buf)]byte{} + } +} + +// XORKeyStream crypts bytes from in to out using the given key and counters. +// In and out must overlap entirely or not at all. Counter contains the raw +// ChaCha20 counter bytes (i.e. block counter followed by nonce). +func XORKeyStream(out, in []byte, counter *[16]byte, key *[32]byte) { + s := Cipher{ + key: [8]uint32{ + binary.LittleEndian.Uint32(key[0:4]), + binary.LittleEndian.Uint32(key[4:8]), + binary.LittleEndian.Uint32(key[8:12]), + binary.LittleEndian.Uint32(key[12:16]), + binary.LittleEndian.Uint32(key[16:20]), + binary.LittleEndian.Uint32(key[20:24]), + binary.LittleEndian.Uint32(key[24:28]), + binary.LittleEndian.Uint32(key[28:32]), + }, + nonce: [3]uint32{ + binary.LittleEndian.Uint32(counter[4:8]), + binary.LittleEndian.Uint32(counter[8:12]), + binary.LittleEndian.Uint32(counter[12:16]), + }, + counter: binary.LittleEndian.Uint32(counter[0:4]), + } + s.XORKeyStream(out, in) +} + +// HChaCha20 uses the ChaCha20 core to generate a derived key from a key and a +// nonce. It should only be used as part of the XChaCha20 construction. +func HChaCha20(key *[8]uint32, nonce *[4]uint32) [8]uint32 { + x0, x1, x2, x3 := j0, j1, j2, j3 + x4, x5, x6, x7 := key[0], key[1], key[2], key[3] + x8, x9, x10, x11 := key[4], key[5], key[6], key[7] + x12, x13, x14, x15 := nonce[0], nonce[1], nonce[2], nonce[3] + + for i := 0; i < 10; i++ { + x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12) + x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13) + x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14) + x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15) + + x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15) + x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12) + x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13) + x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14) + } + + var out [8]uint32 + out[0], out[1], out[2], out[3] = x0, x1, x2, x3 + out[4], out[5], out[6], out[7] = x12, x13, x14, x15 + return out +} diff --git a/vendor/golang.org/x/crypto/internal/chacha20/chacha_noasm.go b/vendor/golang.org/x/crypto/internal/chacha20/chacha_noasm.go new file mode 100644 index 00000000..91520d1d --- /dev/null +++ b/vendor/golang.org/x/crypto/internal/chacha20/chacha_noasm.go @@ -0,0 +1,16 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !s390x gccgo appengine + +package chacha20 + +const ( + bufSize = 64 + haveAsm = false +) + +func (*Cipher) xorKeyStreamAsm(dst, src []byte) { + panic("not implemented") +} diff --git a/vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.go b/vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.go new file mode 100644 index 00000000..0c1c671c --- /dev/null +++ b/vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.go @@ -0,0 +1,30 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build s390x,!gccgo,!appengine + +package chacha20 + +var haveAsm = hasVectorFacility() + +const bufSize = 256 + +// hasVectorFacility reports whether the machine supports the vector +// facility (vx). +// Implementation in asm_s390x.s. +func hasVectorFacility() bool + +// xorKeyStreamVX is an assembly implementation of XORKeyStream. It must only +// be called when the vector facility is available. +// Implementation in asm_s390x.s. +//go:noescape +func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32, buf *[256]byte, len *int) + +func (c *Cipher) xorKeyStreamAsm(dst, src []byte) { + xorKeyStreamVX(dst, src, &c.key, &c.nonce, &c.counter, &c.buf, &c.len) +} + +// EXRL targets, DO NOT CALL! +func mvcSrcToBuf() +func mvcBufToDst() diff --git a/vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.s b/vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.s new file mode 100644 index 00000000..98427c5e --- /dev/null +++ b/vendor/golang.org/x/crypto/internal/chacha20/chacha_s390x.s @@ -0,0 +1,283 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build s390x,!gccgo,!appengine + +#include "go_asm.h" +#include "textflag.h" + +// This is an implementation of the ChaCha20 encryption algorithm as +// specified in RFC 7539. It uses vector instructions to compute +// 4 keystream blocks in parallel (256 bytes) which are then XORed +// with the bytes in the input slice. + +GLOBL ·constants<>(SB), RODATA|NOPTR, $32 +// BSWAP: swap bytes in each 4-byte element +DATA ·constants<>+0x00(SB)/4, $0x03020100 +DATA ·constants<>+0x04(SB)/4, $0x07060504 +DATA ·constants<>+0x08(SB)/4, $0x0b0a0908 +DATA ·constants<>+0x0c(SB)/4, $0x0f0e0d0c +// J0: [j0, j1, j2, j3] +DATA ·constants<>+0x10(SB)/4, $0x61707865 +DATA ·constants<>+0x14(SB)/4, $0x3320646e +DATA ·constants<>+0x18(SB)/4, $0x79622d32 +DATA ·constants<>+0x1c(SB)/4, $0x6b206574 + +// EXRL targets: +TEXT ·mvcSrcToBuf(SB), NOFRAME|NOSPLIT, $0 + MVC $1, (R1), (R8) + RET + +TEXT ·mvcBufToDst(SB), NOFRAME|NOSPLIT, $0 + MVC $1, (R8), (R9) + RET + +#define BSWAP V5 +#define J0 V6 +#define KEY0 V7 +#define KEY1 V8 +#define NONCE V9 +#define CTR V10 +#define M0 V11 +#define M1 V12 +#define M2 V13 +#define M3 V14 +#define INC V15 +#define X0 V16 +#define X1 V17 +#define X2 V18 +#define X3 V19 +#define X4 V20 +#define X5 V21 +#define X6 V22 +#define X7 V23 +#define X8 V24 +#define X9 V25 +#define X10 V26 +#define X11 V27 +#define X12 V28 +#define X13 V29 +#define X14 V30 +#define X15 V31 + +#define NUM_ROUNDS 20 + +#define ROUND4(a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3) \ + VAF a1, a0, a0 \ + VAF b1, b0, b0 \ + VAF c1, c0, c0 \ + VAF d1, d0, d0 \ + VX a0, a2, a2 \ + VX b0, b2, b2 \ + VX c0, c2, c2 \ + VX d0, d2, d2 \ + VERLLF $16, a2, a2 \ + VERLLF $16, b2, b2 \ + VERLLF $16, c2, c2 \ + VERLLF $16, d2, d2 \ + VAF a2, a3, a3 \ + VAF b2, b3, b3 \ + VAF c2, c3, c3 \ + VAF d2, d3, d3 \ + VX a3, a1, a1 \ + VX b3, b1, b1 \ + VX c3, c1, c1 \ + VX d3, d1, d1 \ + VERLLF $12, a1, a1 \ + VERLLF $12, b1, b1 \ + VERLLF $12, c1, c1 \ + VERLLF $12, d1, d1 \ + VAF a1, a0, a0 \ + VAF b1, b0, b0 \ + VAF c1, c0, c0 \ + VAF d1, d0, d0 \ + VX a0, a2, a2 \ + VX b0, b2, b2 \ + VX c0, c2, c2 \ + VX d0, d2, d2 \ + VERLLF $8, a2, a2 \ + VERLLF $8, b2, b2 \ + VERLLF $8, c2, c2 \ + VERLLF $8, d2, d2 \ + VAF a2, a3, a3 \ + VAF b2, b3, b3 \ + VAF c2, c3, c3 \ + VAF d2, d3, d3 \ + VX a3, a1, a1 \ + VX b3, b1, b1 \ + VX c3, c1, c1 \ + VX d3, d1, d1 \ + VERLLF $7, a1, a1 \ + VERLLF $7, b1, b1 \ + VERLLF $7, c1, c1 \ + VERLLF $7, d1, d1 + +#define PERMUTE(mask, v0, v1, v2, v3) \ + VPERM v0, v0, mask, v0 \ + VPERM v1, v1, mask, v1 \ + VPERM v2, v2, mask, v2 \ + VPERM v3, v3, mask, v3 + +#define ADDV(x, v0, v1, v2, v3) \ + VAF x, v0, v0 \ + VAF x, v1, v1 \ + VAF x, v2, v2 \ + VAF x, v3, v3 + +#define XORV(off, dst, src, v0, v1, v2, v3) \ + VLM off(src), M0, M3 \ + PERMUTE(BSWAP, v0, v1, v2, v3) \ + VX v0, M0, M0 \ + VX v1, M1, M1 \ + VX v2, M2, M2 \ + VX v3, M3, M3 \ + VSTM M0, M3, off(dst) + +#define SHUFFLE(a, b, c, d, t, u, v, w) \ + VMRHF a, c, t \ // t = {a[0], c[0], a[1], c[1]} + VMRHF b, d, u \ // u = {b[0], d[0], b[1], d[1]} + VMRLF a, c, v \ // v = {a[2], c[2], a[3], c[3]} + VMRLF b, d, w \ // w = {b[2], d[2], b[3], d[3]} + VMRHF t, u, a \ // a = {a[0], b[0], c[0], d[0]} + VMRLF t, u, b \ // b = {a[1], b[1], c[1], d[1]} + VMRHF v, w, c \ // c = {a[2], b[2], c[2], d[2]} + VMRLF v, w, d // d = {a[3], b[3], c[3], d[3]} + +// func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32, buf *[256]byte, len *int) +TEXT ·xorKeyStreamVX(SB), NOSPLIT, $0 + MOVD $·constants<>(SB), R1 + MOVD dst+0(FP), R2 // R2=&dst[0] + LMG src+24(FP), R3, R4 // R3=&src[0] R4=len(src) + MOVD key+48(FP), R5 // R5=key + MOVD nonce+56(FP), R6 // R6=nonce + MOVD counter+64(FP), R7 // R7=counter + MOVD buf+72(FP), R8 // R8=buf + MOVD len+80(FP), R9 // R9=len + + // load BSWAP and J0 + VLM (R1), BSWAP, J0 + + // set up tail buffer + ADD $-1, R4, R12 + MOVBZ R12, R12 + CMPUBEQ R12, $255, aligned + MOVD R4, R1 + AND $~255, R1 + MOVD $(R3)(R1*1), R1 + EXRL $·mvcSrcToBuf(SB), R12 + MOVD $255, R0 + SUB R12, R0 + MOVD R0, (R9) // update len + +aligned: + // setup + MOVD $95, R0 + VLM (R5), KEY0, KEY1 + VLL R0, (R6), NONCE + VZERO M0 + VLEIB $7, $32, M0 + VSRLB M0, NONCE, NONCE + + // initialize counter values + VLREPF (R7), CTR + VZERO INC + VLEIF $1, $1, INC + VLEIF $2, $2, INC + VLEIF $3, $3, INC + VAF INC, CTR, CTR + VREPIF $4, INC + +chacha: + VREPF $0, J0, X0 + VREPF $1, J0, X1 + VREPF $2, J0, X2 + VREPF $3, J0, X3 + VREPF $0, KEY0, X4 + VREPF $1, KEY0, X5 + VREPF $2, KEY0, X6 + VREPF $3, KEY0, X7 + VREPF $0, KEY1, X8 + VREPF $1, KEY1, X9 + VREPF $2, KEY1, X10 + VREPF $3, KEY1, X11 + VLR CTR, X12 + VREPF $1, NONCE, X13 + VREPF $2, NONCE, X14 + VREPF $3, NONCE, X15 + + MOVD $(NUM_ROUNDS/2), R1 + +loop: + ROUND4(X0, X4, X12, X8, X1, X5, X13, X9, X2, X6, X14, X10, X3, X7, X15, X11) + ROUND4(X0, X5, X15, X10, X1, X6, X12, X11, X2, X7, X13, X8, X3, X4, X14, X9) + + ADD $-1, R1 + BNE loop + + // decrement length + ADD $-256, R4 + BLT tail + +continue: + // rearrange vectors + SHUFFLE(X0, X1, X2, X3, M0, M1, M2, M3) + ADDV(J0, X0, X1, X2, X3) + SHUFFLE(X4, X5, X6, X7, M0, M1, M2, M3) + ADDV(KEY0, X4, X5, X6, X7) + SHUFFLE(X8, X9, X10, X11, M0, M1, M2, M3) + ADDV(KEY1, X8, X9, X10, X11) + VAF CTR, X12, X12 + SHUFFLE(X12, X13, X14, X15, M0, M1, M2, M3) + ADDV(NONCE, X12, X13, X14, X15) + + // increment counters + VAF INC, CTR, CTR + + // xor keystream with plaintext + XORV(0*64, R2, R3, X0, X4, X8, X12) + XORV(1*64, R2, R3, X1, X5, X9, X13) + XORV(2*64, R2, R3, X2, X6, X10, X14) + XORV(3*64, R2, R3, X3, X7, X11, X15) + + // increment pointers + MOVD $256(R2), R2 + MOVD $256(R3), R3 + + CMPBNE R4, $0, chacha + CMPUBEQ R12, $255, return + EXRL $·mvcBufToDst(SB), R12 // len was updated during setup + +return: + VSTEF $0, CTR, (R7) + RET + +tail: + MOVD R2, R9 + MOVD R8, R2 + MOVD R8, R3 + MOVD $0, R4 + JMP continue + +// func hasVectorFacility() bool +TEXT ·hasVectorFacility(SB), NOSPLIT, $24-1 + MOVD $x-24(SP), R1 + XC $24, 0(R1), 0(R1) // clear the storage + MOVD $2, R0 // R0 is the number of double words stored -1 + WORD $0xB2B01000 // STFLE 0(R1) + XOR R0, R0 // reset the value of R0 + MOVBZ z-8(SP), R1 + AND $0x40, R1 + BEQ novector + +vectorinstalled: + // check if the vector instruction has been enabled + VLEIB $0, $0xF, V16 + VLGVB $0, V16, R1 + CMPBNE R1, $0xF, novector + MOVB $1, ret+0(FP) // have vx + RET + +novector: + MOVB $0, ret+0(FP) // no vx + RET diff --git a/vendor/golang.org/x/crypto/internal/chacha20/xor.go b/vendor/golang.org/x/crypto/internal/chacha20/xor.go new file mode 100644 index 00000000..9c5ba0b3 --- /dev/null +++ b/vendor/golang.org/x/crypto/internal/chacha20/xor.go @@ -0,0 +1,43 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found src the LICENSE file. + +package chacha20 + +import ( + "runtime" +) + +// Platforms that have fast unaligned 32-bit little endian accesses. +const unaligned = runtime.GOARCH == "386" || + runtime.GOARCH == "amd64" || + runtime.GOARCH == "arm64" || + runtime.GOARCH == "ppc64le" || + runtime.GOARCH == "s390x" + +// xor reads a little endian uint32 from src, XORs it with u and +// places the result in little endian byte order in dst. +func xor(dst, src []byte, u uint32) { + _, _ = src[3], dst[3] // eliminate bounds checks + if unaligned { + // The compiler should optimize this code into + // 32-bit unaligned little endian loads and stores. + // TODO: delete once the compiler does a reliably + // good job with the generic code below. + // See issue #25111 for more details. + v := uint32(src[0]) + v |= uint32(src[1]) << 8 + v |= uint32(src[2]) << 16 + v |= uint32(src[3]) << 24 + v ^= u + dst[0] = byte(v) + dst[1] = byte(v >> 8) + dst[2] = byte(v >> 16) + dst[3] = byte(v >> 24) + } else { + dst[0] = src[0] ^ byte(u) + dst[1] = src[1] ^ byte(u>>8) + dst[2] = src[2] ^ byte(u>>16) + dst[3] = src[3] ^ byte(u>>24) + } +} diff --git a/vendor/golang.org/x/crypto/internal/subtle/aliasing.go b/vendor/golang.org/x/crypto/internal/subtle/aliasing.go new file mode 100644 index 00000000..f38797bf --- /dev/null +++ b/vendor/golang.org/x/crypto/internal/subtle/aliasing.go @@ -0,0 +1,32 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine + +// Package subtle implements functions that are often useful in cryptographic +// code but require careful thought to use correctly. +package subtle // import "golang.org/x/crypto/internal/subtle" + +import "unsafe" + +// AnyOverlap reports whether x and y share memory at any (not necessarily +// corresponding) index. The memory beyond the slice length is ignored. +func AnyOverlap(x, y []byte) bool { + return len(x) > 0 && len(y) > 0 && + uintptr(unsafe.Pointer(&x[0])) <= uintptr(unsafe.Pointer(&y[len(y)-1])) && + uintptr(unsafe.Pointer(&y[0])) <= uintptr(unsafe.Pointer(&x[len(x)-1])) +} + +// InexactOverlap reports whether x and y share memory at any non-corresponding +// index. The memory beyond the slice length is ignored. Note that x and y can +// have different lengths and still not have any inexact overlap. +// +// InexactOverlap can be used to implement the requirements of the crypto/cipher +// AEAD, Block, BlockMode and Stream interfaces. +func InexactOverlap(x, y []byte) bool { + if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] { + return false + } + return AnyOverlap(x, y) +} diff --git a/vendor/golang.org/x/crypto/internal/subtle/aliasing_appengine.go b/vendor/golang.org/x/crypto/internal/subtle/aliasing_appengine.go new file mode 100644 index 00000000..0cc4a8a6 --- /dev/null +++ b/vendor/golang.org/x/crypto/internal/subtle/aliasing_appengine.go @@ -0,0 +1,35 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build appengine + +// Package subtle implements functions that are often useful in cryptographic +// code but require careful thought to use correctly. +package subtle // import "golang.org/x/crypto/internal/subtle" + +// This is the Google App Engine standard variant based on reflect +// because the unsafe package and cgo are disallowed. + +import "reflect" + +// AnyOverlap reports whether x and y share memory at any (not necessarily +// corresponding) index. The memory beyond the slice length is ignored. +func AnyOverlap(x, y []byte) bool { + return len(x) > 0 && len(y) > 0 && + reflect.ValueOf(&x[0]).Pointer() <= reflect.ValueOf(&y[len(y)-1]).Pointer() && + reflect.ValueOf(&y[0]).Pointer() <= reflect.ValueOf(&x[len(x)-1]).Pointer() +} + +// InexactOverlap reports whether x and y share memory at any non-corresponding +// index. The memory beyond the slice length is ignored. Note that x and y can +// have different lengths and still not have any inexact overlap. +// +// InexactOverlap can be used to implement the requirements of the crypto/cipher +// AEAD, Block, BlockMode and Stream interfaces. +func InexactOverlap(x, y []byte) bool { + if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] { + return false + } + return AnyOverlap(x, y) +} diff --git a/vendor/golang.org/x/crypto/poly1305/poly1305.go b/vendor/golang.org/x/crypto/poly1305/poly1305.go new file mode 100644 index 00000000..f562fa57 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/poly1305.go @@ -0,0 +1,33 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package poly1305 implements Poly1305 one-time message authentication code as +specified in https://cr.yp.to/mac/poly1305-20050329.pdf. + +Poly1305 is a fast, one-time authentication function. It is infeasible for an +attacker to generate an authenticator for a message without the key. However, a +key must only be used for a single message. Authenticating two different +messages with the same key allows an attacker to forge authenticators for other +messages with the same key. + +Poly1305 was originally coupled with AES in order to make Poly1305-AES. AES was +used with a fixed key in order to generate one-time keys from an nonce. +However, in this package AES isn't used and the one-time key is specified +directly. +*/ +package poly1305 // import "golang.org/x/crypto/poly1305" + +import "crypto/subtle" + +// TagSize is the size, in bytes, of a poly1305 authenticator. +const TagSize = 16 + +// Verify returns true if mac is a valid authenticator for m with the given +// key. +func Verify(mac *[16]byte, m []byte, key *[32]byte) bool { + var tmp [16]byte + Sum(&tmp, m, key) + return subtle.ConstantTimeCompare(tmp[:], mac[:]) == 1 +} diff --git a/vendor/golang.org/x/crypto/poly1305/sum_amd64.go b/vendor/golang.org/x/crypto/poly1305/sum_amd64.go new file mode 100644 index 00000000..4dd72fe7 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_amd64.go @@ -0,0 +1,22 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build amd64,!gccgo,!appengine + +package poly1305 + +// This function is implemented in sum_amd64.s +//go:noescape +func poly1305(out *[16]byte, m *byte, mlen uint64, key *[32]byte) + +// Sum generates an authenticator for m using a one-time key and puts the +// 16-byte result into out. Authenticating two different messages with the same +// key allows an attacker to forge messages at will. +func Sum(out *[16]byte, m []byte, key *[32]byte) { + var mPtr *byte + if len(m) > 0 { + mPtr = &m[0] + } + poly1305(out, mPtr, uint64(len(m)), key) +} diff --git a/vendor/golang.org/x/crypto/poly1305/sum_amd64.s b/vendor/golang.org/x/crypto/poly1305/sum_amd64.s new file mode 100644 index 00000000..2edae638 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_amd64.s @@ -0,0 +1,125 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build amd64,!gccgo,!appengine + +#include "textflag.h" + +#define POLY1305_ADD(msg, h0, h1, h2) \ + ADDQ 0(msg), h0; \ + ADCQ 8(msg), h1; \ + ADCQ $1, h2; \ + LEAQ 16(msg), msg + +#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3) \ + MOVQ r0, AX; \ + MULQ h0; \ + MOVQ AX, t0; \ + MOVQ DX, t1; \ + MOVQ r0, AX; \ + MULQ h1; \ + ADDQ AX, t1; \ + ADCQ $0, DX; \ + MOVQ r0, t2; \ + IMULQ h2, t2; \ + ADDQ DX, t2; \ + \ + MOVQ r1, AX; \ + MULQ h0; \ + ADDQ AX, t1; \ + ADCQ $0, DX; \ + MOVQ DX, h0; \ + MOVQ r1, t3; \ + IMULQ h2, t3; \ + MOVQ r1, AX; \ + MULQ h1; \ + ADDQ AX, t2; \ + ADCQ DX, t3; \ + ADDQ h0, t2; \ + ADCQ $0, t3; \ + \ + MOVQ t0, h0; \ + MOVQ t1, h1; \ + MOVQ t2, h2; \ + ANDQ $3, h2; \ + MOVQ t2, t0; \ + ANDQ $0xFFFFFFFFFFFFFFFC, t0; \ + ADDQ t0, h0; \ + ADCQ t3, h1; \ + ADCQ $0, h2; \ + SHRQ $2, t3, t2; \ + SHRQ $2, t3; \ + ADDQ t2, h0; \ + ADCQ t3, h1; \ + ADCQ $0, h2 + +DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF +DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC +GLOBL ·poly1305Mask<>(SB), RODATA, $16 + +// func poly1305(out *[16]byte, m *byte, mlen uint64, key *[32]key) +TEXT ·poly1305(SB), $0-32 + MOVQ out+0(FP), DI + MOVQ m+8(FP), SI + MOVQ mlen+16(FP), R15 + MOVQ key+24(FP), AX + + MOVQ 0(AX), R11 + MOVQ 8(AX), R12 + ANDQ ·poly1305Mask<>(SB), R11 // r0 + ANDQ ·poly1305Mask<>+8(SB), R12 // r1 + XORQ R8, R8 // h0 + XORQ R9, R9 // h1 + XORQ R10, R10 // h2 + + CMPQ R15, $16 + JB bytes_between_0_and_15 + +loop: + POLY1305_ADD(SI, R8, R9, R10) + +multiply: + POLY1305_MUL(R8, R9, R10, R11, R12, BX, CX, R13, R14) + SUBQ $16, R15 + CMPQ R15, $16 + JAE loop + +bytes_between_0_and_15: + TESTQ R15, R15 + JZ done + MOVQ $1, BX + XORQ CX, CX + XORQ R13, R13 + ADDQ R15, SI + +flush_buffer: + SHLQ $8, BX, CX + SHLQ $8, BX + MOVB -1(SI), R13 + XORQ R13, BX + DECQ SI + DECQ R15 + JNZ flush_buffer + + ADDQ BX, R8 + ADCQ CX, R9 + ADCQ $0, R10 + MOVQ $16, R15 + JMP multiply + +done: + MOVQ R8, AX + MOVQ R9, BX + SUBQ $0xFFFFFFFFFFFFFFFB, AX + SBBQ $0xFFFFFFFFFFFFFFFF, BX + SBBQ $3, R10 + CMOVQCS R8, AX + CMOVQCS R9, BX + MOVQ key+24(FP), R8 + ADDQ 16(R8), AX + ADCQ 24(R8), BX + + MOVQ AX, 0(DI) + MOVQ BX, 8(DI) + RET diff --git a/vendor/golang.org/x/crypto/poly1305/sum_arm.go b/vendor/golang.org/x/crypto/poly1305/sum_arm.go new file mode 100644 index 00000000..5dc321c2 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_arm.go @@ -0,0 +1,22 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm,!gccgo,!appengine,!nacl + +package poly1305 + +// This function is implemented in sum_arm.s +//go:noescape +func poly1305_auth_armv6(out *[16]byte, m *byte, mlen uint32, key *[32]byte) + +// Sum generates an authenticator for m using a one-time key and puts the +// 16-byte result into out. Authenticating two different messages with the same +// key allows an attacker to forge messages at will. +func Sum(out *[16]byte, m []byte, key *[32]byte) { + var mPtr *byte + if len(m) > 0 { + mPtr = &m[0] + } + poly1305_auth_armv6(out, mPtr, uint32(len(m)), key) +} diff --git a/vendor/golang.org/x/crypto/poly1305/sum_arm.s b/vendor/golang.org/x/crypto/poly1305/sum_arm.s new file mode 100644 index 00000000..f70b4ac4 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_arm.s @@ -0,0 +1,427 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build arm,!gccgo,!appengine,!nacl + +#include "textflag.h" + +// This code was translated into a form compatible with 5a from the public +// domain source by Andrew Moon: github.com/floodyberry/poly1305-opt/blob/master/app/extensions/poly1305. + +DATA ·poly1305_init_constants_armv6<>+0x00(SB)/4, $0x3ffffff +DATA ·poly1305_init_constants_armv6<>+0x04(SB)/4, $0x3ffff03 +DATA ·poly1305_init_constants_armv6<>+0x08(SB)/4, $0x3ffc0ff +DATA ·poly1305_init_constants_armv6<>+0x0c(SB)/4, $0x3f03fff +DATA ·poly1305_init_constants_armv6<>+0x10(SB)/4, $0x00fffff +GLOBL ·poly1305_init_constants_armv6<>(SB), 8, $20 + +// Warning: the linker may use R11 to synthesize certain instructions. Please +// take care and verify that no synthetic instructions use it. + +TEXT poly1305_init_ext_armv6<>(SB), NOSPLIT, $0 + // Needs 16 bytes of stack and 64 bytes of space pointed to by R0. (It + // might look like it's only 60 bytes of space but the final four bytes + // will be written by another function.) We need to skip over four + // bytes of stack because that's saving the value of 'g'. + ADD $4, R13, R8 + MOVM.IB [R4-R7], (R8) + MOVM.IA.W (R1), [R2-R5] + MOVW $·poly1305_init_constants_armv6<>(SB), R7 + MOVW R2, R8 + MOVW R2>>26, R9 + MOVW R3>>20, g + MOVW R4>>14, R11 + MOVW R5>>8, R12 + ORR R3<<6, R9, R9 + ORR R4<<12, g, g + ORR R5<<18, R11, R11 + MOVM.IA (R7), [R2-R6] + AND R8, R2, R2 + AND R9, R3, R3 + AND g, R4, R4 + AND R11, R5, R5 + AND R12, R6, R6 + MOVM.IA.W [R2-R6], (R0) + EOR R2, R2, R2 + EOR R3, R3, R3 + EOR R4, R4, R4 + EOR R5, R5, R5 + EOR R6, R6, R6 + MOVM.IA.W [R2-R6], (R0) + MOVM.IA.W (R1), [R2-R5] + MOVM.IA [R2-R6], (R0) + ADD $20, R13, R0 + MOVM.DA (R0), [R4-R7] + RET + +#define MOVW_UNALIGNED(Rsrc, Rdst, Rtmp, offset) \ + MOVBU (offset+0)(Rsrc), Rtmp; \ + MOVBU Rtmp, (offset+0)(Rdst); \ + MOVBU (offset+1)(Rsrc), Rtmp; \ + MOVBU Rtmp, (offset+1)(Rdst); \ + MOVBU (offset+2)(Rsrc), Rtmp; \ + MOVBU Rtmp, (offset+2)(Rdst); \ + MOVBU (offset+3)(Rsrc), Rtmp; \ + MOVBU Rtmp, (offset+3)(Rdst) + +TEXT poly1305_blocks_armv6<>(SB), NOSPLIT, $0 + // Needs 24 bytes of stack for saved registers and then 88 bytes of + // scratch space after that. We assume that 24 bytes at (R13) have + // already been used: four bytes for the link register saved in the + // prelude of poly1305_auth_armv6, four bytes for saving the value of g + // in that function and 16 bytes of scratch space used around + // poly1305_finish_ext_armv6_skip1. + ADD $24, R13, R12 + MOVM.IB [R4-R8, R14], (R12) + MOVW R0, 88(R13) + MOVW R1, 92(R13) + MOVW R2, 96(R13) + MOVW R1, R14 + MOVW R2, R12 + MOVW 56(R0), R8 + WORD $0xe1180008 // TST R8, R8 not working see issue 5921 + EOR R6, R6, R6 + MOVW.EQ $(1<<24), R6 + MOVW R6, 84(R13) + ADD $116, R13, g + MOVM.IA (R0), [R0-R9] + MOVM.IA [R0-R4], (g) + CMP $16, R12 + BLO poly1305_blocks_armv6_done + +poly1305_blocks_armv6_mainloop: + WORD $0xe31e0003 // TST R14, #3 not working see issue 5921 + BEQ poly1305_blocks_armv6_mainloop_aligned + ADD $100, R13, g + MOVW_UNALIGNED(R14, g, R0, 0) + MOVW_UNALIGNED(R14, g, R0, 4) + MOVW_UNALIGNED(R14, g, R0, 8) + MOVW_UNALIGNED(R14, g, R0, 12) + MOVM.IA (g), [R0-R3] + ADD $16, R14 + B poly1305_blocks_armv6_mainloop_loaded + +poly1305_blocks_armv6_mainloop_aligned: + MOVM.IA.W (R14), [R0-R3] + +poly1305_blocks_armv6_mainloop_loaded: + MOVW R0>>26, g + MOVW R1>>20, R11 + MOVW R2>>14, R12 + MOVW R14, 92(R13) + MOVW R3>>8, R4 + ORR R1<<6, g, g + ORR R2<<12, R11, R11 + ORR R3<<18, R12, R12 + BIC $0xfc000000, R0, R0 + BIC $0xfc000000, g, g + MOVW 84(R13), R3 + BIC $0xfc000000, R11, R11 + BIC $0xfc000000, R12, R12 + ADD R0, R5, R5 + ADD g, R6, R6 + ORR R3, R4, R4 + ADD R11, R7, R7 + ADD $116, R13, R14 + ADD R12, R8, R8 + ADD R4, R9, R9 + MOVM.IA (R14), [R0-R4] + MULLU R4, R5, (R11, g) + MULLU R3, R5, (R14, R12) + MULALU R3, R6, (R11, g) + MULALU R2, R6, (R14, R12) + MULALU R2, R7, (R11, g) + MULALU R1, R7, (R14, R12) + ADD R4<<2, R4, R4 + ADD R3<<2, R3, R3 + MULALU R1, R8, (R11, g) + MULALU R0, R8, (R14, R12) + MULALU R0, R9, (R11, g) + MULALU R4, R9, (R14, R12) + MOVW g, 76(R13) + MOVW R11, 80(R13) + MOVW R12, 68(R13) + MOVW R14, 72(R13) + MULLU R2, R5, (R11, g) + MULLU R1, R5, (R14, R12) + MULALU R1, R6, (R11, g) + MULALU R0, R6, (R14, R12) + MULALU R0, R7, (R11, g) + MULALU R4, R7, (R14, R12) + ADD R2<<2, R2, R2 + ADD R1<<2, R1, R1 + MULALU R4, R8, (R11, g) + MULALU R3, R8, (R14, R12) + MULALU R3, R9, (R11, g) + MULALU R2, R9, (R14, R12) + MOVW g, 60(R13) + MOVW R11, 64(R13) + MOVW R12, 52(R13) + MOVW R14, 56(R13) + MULLU R0, R5, (R11, g) + MULALU R4, R6, (R11, g) + MULALU R3, R7, (R11, g) + MULALU R2, R8, (R11, g) + MULALU R1, R9, (R11, g) + ADD $52, R13, R0 + MOVM.IA (R0), [R0-R7] + MOVW g>>26, R12 + MOVW R4>>26, R14 + ORR R11<<6, R12, R12 + ORR R5<<6, R14, R14 + BIC $0xfc000000, g, g + BIC $0xfc000000, R4, R4 + ADD.S R12, R0, R0 + ADC $0, R1, R1 + ADD.S R14, R6, R6 + ADC $0, R7, R7 + MOVW R0>>26, R12 + MOVW R6>>26, R14 + ORR R1<<6, R12, R12 + ORR R7<<6, R14, R14 + BIC $0xfc000000, R0, R0 + BIC $0xfc000000, R6, R6 + ADD R14<<2, R14, R14 + ADD.S R12, R2, R2 + ADC $0, R3, R3 + ADD R14, g, g + MOVW R2>>26, R12 + MOVW g>>26, R14 + ORR R3<<6, R12, R12 + BIC $0xfc000000, g, R5 + BIC $0xfc000000, R2, R7 + ADD R12, R4, R4 + ADD R14, R0, R0 + MOVW R4>>26, R12 + BIC $0xfc000000, R4, R8 + ADD R12, R6, R9 + MOVW 96(R13), R12 + MOVW 92(R13), R14 + MOVW R0, R6 + CMP $32, R12 + SUB $16, R12, R12 + MOVW R12, 96(R13) + BHS poly1305_blocks_armv6_mainloop + +poly1305_blocks_armv6_done: + MOVW 88(R13), R12 + MOVW R5, 20(R12) + MOVW R6, 24(R12) + MOVW R7, 28(R12) + MOVW R8, 32(R12) + MOVW R9, 36(R12) + ADD $48, R13, R0 + MOVM.DA (R0), [R4-R8, R14] + RET + +#define MOVHUP_UNALIGNED(Rsrc, Rdst, Rtmp) \ + MOVBU.P 1(Rsrc), Rtmp; \ + MOVBU.P Rtmp, 1(Rdst); \ + MOVBU.P 1(Rsrc), Rtmp; \ + MOVBU.P Rtmp, 1(Rdst) + +#define MOVWP_UNALIGNED(Rsrc, Rdst, Rtmp) \ + MOVHUP_UNALIGNED(Rsrc, Rdst, Rtmp); \ + MOVHUP_UNALIGNED(Rsrc, Rdst, Rtmp) + +// func poly1305_auth_armv6(out *[16]byte, m *byte, mlen uint32, key *[32]key) +TEXT ·poly1305_auth_armv6(SB), $196-16 + // The value 196, just above, is the sum of 64 (the size of the context + // structure) and 132 (the amount of stack needed). + // + // At this point, the stack pointer (R13) has been moved down. It + // points to the saved link register and there's 196 bytes of free + // space above it. + // + // The stack for this function looks like: + // + // +--------------------- + // | + // | 64 bytes of context structure + // | + // +--------------------- + // | + // | 112 bytes for poly1305_blocks_armv6 + // | + // +--------------------- + // | 16 bytes of final block, constructed at + // | poly1305_finish_ext_armv6_skip8 + // +--------------------- + // | four bytes of saved 'g' + // +--------------------- + // | lr, saved by prelude <- R13 points here + // +--------------------- + MOVW g, 4(R13) + + MOVW out+0(FP), R4 + MOVW m+4(FP), R5 + MOVW mlen+8(FP), R6 + MOVW key+12(FP), R7 + + ADD $136, R13, R0 // 136 = 4 + 4 + 16 + 112 + MOVW R7, R1 + + // poly1305_init_ext_armv6 will write to the stack from R13+4, but + // that's ok because none of the other values have been written yet. + BL poly1305_init_ext_armv6<>(SB) + BIC.S $15, R6, R2 + BEQ poly1305_auth_armv6_noblocks + ADD $136, R13, R0 + MOVW R5, R1 + ADD R2, R5, R5 + SUB R2, R6, R6 + BL poly1305_blocks_armv6<>(SB) + +poly1305_auth_armv6_noblocks: + ADD $136, R13, R0 + MOVW R5, R1 + MOVW R6, R2 + MOVW R4, R3 + + MOVW R0, R5 + MOVW R1, R6 + MOVW R2, R7 + MOVW R3, R8 + AND.S R2, R2, R2 + BEQ poly1305_finish_ext_armv6_noremaining + EOR R0, R0 + ADD $8, R13, R9 // 8 = offset to 16 byte scratch space + MOVW R0, (R9) + MOVW R0, 4(R9) + MOVW R0, 8(R9) + MOVW R0, 12(R9) + WORD $0xe3110003 // TST R1, #3 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_aligned + WORD $0xe3120008 // TST R2, #8 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_skip8 + MOVWP_UNALIGNED(R1, R9, g) + MOVWP_UNALIGNED(R1, R9, g) + +poly1305_finish_ext_armv6_skip8: + WORD $0xe3120004 // TST $4, R2 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_skip4 + MOVWP_UNALIGNED(R1, R9, g) + +poly1305_finish_ext_armv6_skip4: + WORD $0xe3120002 // TST $2, R2 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_skip2 + MOVHUP_UNALIGNED(R1, R9, g) + B poly1305_finish_ext_armv6_skip2 + +poly1305_finish_ext_armv6_aligned: + WORD $0xe3120008 // TST R2, #8 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_skip8_aligned + MOVM.IA.W (R1), [g-R11] + MOVM.IA.W [g-R11], (R9) + +poly1305_finish_ext_armv6_skip8_aligned: + WORD $0xe3120004 // TST $4, R2 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_skip4_aligned + MOVW.P 4(R1), g + MOVW.P g, 4(R9) + +poly1305_finish_ext_armv6_skip4_aligned: + WORD $0xe3120002 // TST $2, R2 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_skip2 + MOVHU.P 2(R1), g + MOVH.P g, 2(R9) + +poly1305_finish_ext_armv6_skip2: + WORD $0xe3120001 // TST $1, R2 not working see issue 5921 + BEQ poly1305_finish_ext_armv6_skip1 + MOVBU.P 1(R1), g + MOVBU.P g, 1(R9) + +poly1305_finish_ext_armv6_skip1: + MOVW $1, R11 + MOVBU R11, 0(R9) + MOVW R11, 56(R5) + MOVW R5, R0 + ADD $8, R13, R1 + MOVW $16, R2 + BL poly1305_blocks_armv6<>(SB) + +poly1305_finish_ext_armv6_noremaining: + MOVW 20(R5), R0 + MOVW 24(R5), R1 + MOVW 28(R5), R2 + MOVW 32(R5), R3 + MOVW 36(R5), R4 + MOVW R4>>26, R12 + BIC $0xfc000000, R4, R4 + ADD R12<<2, R12, R12 + ADD R12, R0, R0 + MOVW R0>>26, R12 + BIC $0xfc000000, R0, R0 + ADD R12, R1, R1 + MOVW R1>>26, R12 + BIC $0xfc000000, R1, R1 + ADD R12, R2, R2 + MOVW R2>>26, R12 + BIC $0xfc000000, R2, R2 + ADD R12, R3, R3 + MOVW R3>>26, R12 + BIC $0xfc000000, R3, R3 + ADD R12, R4, R4 + ADD $5, R0, R6 + MOVW R6>>26, R12 + BIC $0xfc000000, R6, R6 + ADD R12, R1, R7 + MOVW R7>>26, R12 + BIC $0xfc000000, R7, R7 + ADD R12, R2, g + MOVW g>>26, R12 + BIC $0xfc000000, g, g + ADD R12, R3, R11 + MOVW $-(1<<26), R12 + ADD R11>>26, R12, R12 + BIC $0xfc000000, R11, R11 + ADD R12, R4, R9 + MOVW R9>>31, R12 + SUB $1, R12 + AND R12, R6, R6 + AND R12, R7, R7 + AND R12, g, g + AND R12, R11, R11 + AND R12, R9, R9 + MVN R12, R12 + AND R12, R0, R0 + AND R12, R1, R1 + AND R12, R2, R2 + AND R12, R3, R3 + AND R12, R4, R4 + ORR R6, R0, R0 + ORR R7, R1, R1 + ORR g, R2, R2 + ORR R11, R3, R3 + ORR R9, R4, R4 + ORR R1<<26, R0, R0 + MOVW R1>>6, R1 + ORR R2<<20, R1, R1 + MOVW R2>>12, R2 + ORR R3<<14, R2, R2 + MOVW R3>>18, R3 + ORR R4<<8, R3, R3 + MOVW 40(R5), R6 + MOVW 44(R5), R7 + MOVW 48(R5), g + MOVW 52(R5), R11 + ADD.S R6, R0, R0 + ADC.S R7, R1, R1 + ADC.S g, R2, R2 + ADC.S R11, R3, R3 + MOVM.IA [R0-R3], (R8) + MOVW R5, R12 + EOR R0, R0, R0 + EOR R1, R1, R1 + EOR R2, R2, R2 + EOR R3, R3, R3 + EOR R4, R4, R4 + EOR R5, R5, R5 + EOR R6, R6, R6 + EOR R7, R7, R7 + MOVM.IA.W [R0-R7], (R12) + MOVM.IA [R0-R7], (R12) + MOVW 4(R13), g + RET diff --git a/vendor/golang.org/x/crypto/poly1305/sum_noasm.go b/vendor/golang.org/x/crypto/poly1305/sum_noasm.go new file mode 100644 index 00000000..751eec52 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_noasm.go @@ -0,0 +1,14 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build s390x,!go1.11 !arm,!amd64,!s390x gccgo appengine nacl + +package poly1305 + +// Sum generates an authenticator for msg using a one-time key and puts the +// 16-byte result into out. Authenticating two different messages with the same +// key allows an attacker to forge messages at will. +func Sum(out *[TagSize]byte, msg []byte, key *[32]byte) { + sumGeneric(out, msg, key) +} diff --git a/vendor/golang.org/x/crypto/poly1305/sum_ref.go b/vendor/golang.org/x/crypto/poly1305/sum_ref.go new file mode 100644 index 00000000..c4d59bd0 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_ref.go @@ -0,0 +1,139 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package poly1305 + +import "encoding/binary" + +// sumGeneric generates an authenticator for msg using a one-time key and +// puts the 16-byte result into out. This is the generic implementation of +// Sum and should be called if no assembly implementation is available. +func sumGeneric(out *[TagSize]byte, msg []byte, key *[32]byte) { + var ( + h0, h1, h2, h3, h4 uint32 // the hash accumulators + r0, r1, r2, r3, r4 uint64 // the r part of the key + ) + + r0 = uint64(binary.LittleEndian.Uint32(key[0:]) & 0x3ffffff) + r1 = uint64((binary.LittleEndian.Uint32(key[3:]) >> 2) & 0x3ffff03) + r2 = uint64((binary.LittleEndian.Uint32(key[6:]) >> 4) & 0x3ffc0ff) + r3 = uint64((binary.LittleEndian.Uint32(key[9:]) >> 6) & 0x3f03fff) + r4 = uint64((binary.LittleEndian.Uint32(key[12:]) >> 8) & 0x00fffff) + + R1, R2, R3, R4 := r1*5, r2*5, r3*5, r4*5 + + for len(msg) >= TagSize { + // h += msg + h0 += binary.LittleEndian.Uint32(msg[0:]) & 0x3ffffff + h1 += (binary.LittleEndian.Uint32(msg[3:]) >> 2) & 0x3ffffff + h2 += (binary.LittleEndian.Uint32(msg[6:]) >> 4) & 0x3ffffff + h3 += (binary.LittleEndian.Uint32(msg[9:]) >> 6) & 0x3ffffff + h4 += (binary.LittleEndian.Uint32(msg[12:]) >> 8) | (1 << 24) + + // h *= r + d0 := (uint64(h0) * r0) + (uint64(h1) * R4) + (uint64(h2) * R3) + (uint64(h3) * R2) + (uint64(h4) * R1) + d1 := (d0 >> 26) + (uint64(h0) * r1) + (uint64(h1) * r0) + (uint64(h2) * R4) + (uint64(h3) * R3) + (uint64(h4) * R2) + d2 := (d1 >> 26) + (uint64(h0) * r2) + (uint64(h1) * r1) + (uint64(h2) * r0) + (uint64(h3) * R4) + (uint64(h4) * R3) + d3 := (d2 >> 26) + (uint64(h0) * r3) + (uint64(h1) * r2) + (uint64(h2) * r1) + (uint64(h3) * r0) + (uint64(h4) * R4) + d4 := (d3 >> 26) + (uint64(h0) * r4) + (uint64(h1) * r3) + (uint64(h2) * r2) + (uint64(h3) * r1) + (uint64(h4) * r0) + + // h %= p + h0 = uint32(d0) & 0x3ffffff + h1 = uint32(d1) & 0x3ffffff + h2 = uint32(d2) & 0x3ffffff + h3 = uint32(d3) & 0x3ffffff + h4 = uint32(d4) & 0x3ffffff + + h0 += uint32(d4>>26) * 5 + h1 += h0 >> 26 + h0 = h0 & 0x3ffffff + + msg = msg[TagSize:] + } + + if len(msg) > 0 { + var block [TagSize]byte + off := copy(block[:], msg) + block[off] = 0x01 + + // h += msg + h0 += binary.LittleEndian.Uint32(block[0:]) & 0x3ffffff + h1 += (binary.LittleEndian.Uint32(block[3:]) >> 2) & 0x3ffffff + h2 += (binary.LittleEndian.Uint32(block[6:]) >> 4) & 0x3ffffff + h3 += (binary.LittleEndian.Uint32(block[9:]) >> 6) & 0x3ffffff + h4 += (binary.LittleEndian.Uint32(block[12:]) >> 8) + + // h *= r + d0 := (uint64(h0) * r0) + (uint64(h1) * R4) + (uint64(h2) * R3) + (uint64(h3) * R2) + (uint64(h4) * R1) + d1 := (d0 >> 26) + (uint64(h0) * r1) + (uint64(h1) * r0) + (uint64(h2) * R4) + (uint64(h3) * R3) + (uint64(h4) * R2) + d2 := (d1 >> 26) + (uint64(h0) * r2) + (uint64(h1) * r1) + (uint64(h2) * r0) + (uint64(h3) * R4) + (uint64(h4) * R3) + d3 := (d2 >> 26) + (uint64(h0) * r3) + (uint64(h1) * r2) + (uint64(h2) * r1) + (uint64(h3) * r0) + (uint64(h4) * R4) + d4 := (d3 >> 26) + (uint64(h0) * r4) + (uint64(h1) * r3) + (uint64(h2) * r2) + (uint64(h3) * r1) + (uint64(h4) * r0) + + // h %= p + h0 = uint32(d0) & 0x3ffffff + h1 = uint32(d1) & 0x3ffffff + h2 = uint32(d2) & 0x3ffffff + h3 = uint32(d3) & 0x3ffffff + h4 = uint32(d4) & 0x3ffffff + + h0 += uint32(d4>>26) * 5 + h1 += h0 >> 26 + h0 = h0 & 0x3ffffff + } + + // h %= p reduction + h2 += h1 >> 26 + h1 &= 0x3ffffff + h3 += h2 >> 26 + h2 &= 0x3ffffff + h4 += h3 >> 26 + h3 &= 0x3ffffff + h0 += 5 * (h4 >> 26) + h4 &= 0x3ffffff + h1 += h0 >> 26 + h0 &= 0x3ffffff + + // h - p + t0 := h0 + 5 + t1 := h1 + (t0 >> 26) + t2 := h2 + (t1 >> 26) + t3 := h3 + (t2 >> 26) + t4 := h4 + (t3 >> 26) - (1 << 26) + t0 &= 0x3ffffff + t1 &= 0x3ffffff + t2 &= 0x3ffffff + t3 &= 0x3ffffff + + // select h if h < p else h - p + t_mask := (t4 >> 31) - 1 + h_mask := ^t_mask + h0 = (h0 & h_mask) | (t0 & t_mask) + h1 = (h1 & h_mask) | (t1 & t_mask) + h2 = (h2 & h_mask) | (t2 & t_mask) + h3 = (h3 & h_mask) | (t3 & t_mask) + h4 = (h4 & h_mask) | (t4 & t_mask) + + // h %= 2^128 + h0 |= h1 << 26 + h1 = ((h1 >> 6) | (h2 << 20)) + h2 = ((h2 >> 12) | (h3 << 14)) + h3 = ((h3 >> 18) | (h4 << 8)) + + // s: the s part of the key + // tag = (h + s) % (2^128) + t := uint64(h0) + uint64(binary.LittleEndian.Uint32(key[16:])) + h0 = uint32(t) + t = uint64(h1) + uint64(binary.LittleEndian.Uint32(key[20:])) + (t >> 32) + h1 = uint32(t) + t = uint64(h2) + uint64(binary.LittleEndian.Uint32(key[24:])) + (t >> 32) + h2 = uint32(t) + t = uint64(h3) + uint64(binary.LittleEndian.Uint32(key[28:])) + (t >> 32) + h3 = uint32(t) + + binary.LittleEndian.PutUint32(out[0:], h0) + binary.LittleEndian.PutUint32(out[4:], h1) + binary.LittleEndian.PutUint32(out[8:], h2) + binary.LittleEndian.PutUint32(out[12:], h3) +} diff --git a/vendor/golang.org/x/crypto/poly1305/sum_s390x.go b/vendor/golang.org/x/crypto/poly1305/sum_s390x.go new file mode 100644 index 00000000..7a266cec --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_s390x.go @@ -0,0 +1,49 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build s390x,go1.11,!gccgo,!appengine + +package poly1305 + +// hasVectorFacility reports whether the machine supports +// the vector facility (vx). +func hasVectorFacility() bool + +// hasVMSLFacility reports whether the machine supports +// Vector Multiply Sum Logical (VMSL). +func hasVMSLFacility() bool + +var hasVX = hasVectorFacility() +var hasVMSL = hasVMSLFacility() + +// poly1305vx is an assembly implementation of Poly1305 that uses vector +// instructions. It must only be called if the vector facility (vx) is +// available. +//go:noescape +func poly1305vx(out *[16]byte, m *byte, mlen uint64, key *[32]byte) + +// poly1305vmsl is an assembly implementation of Poly1305 that uses vector +// instructions, including VMSL. It must only be called if the vector facility (vx) is +// available and if VMSL is supported. +//go:noescape +func poly1305vmsl(out *[16]byte, m *byte, mlen uint64, key *[32]byte) + +// Sum generates an authenticator for m using a one-time key and puts the +// 16-byte result into out. Authenticating two different messages with the same +// key allows an attacker to forge messages at will. +func Sum(out *[16]byte, m []byte, key *[32]byte) { + if hasVX { + var mPtr *byte + if len(m) > 0 { + mPtr = &m[0] + } + if hasVMSL && len(m) > 256 { + poly1305vmsl(out, mPtr, uint64(len(m)), key) + } else { + poly1305vx(out, mPtr, uint64(len(m)), key) + } + } else { + sumGeneric(out, m, key) + } +} diff --git a/vendor/golang.org/x/crypto/poly1305/sum_s390x.s b/vendor/golang.org/x/crypto/poly1305/sum_s390x.s new file mode 100644 index 00000000..356c07a6 --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_s390x.s @@ -0,0 +1,400 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build s390x,go1.11,!gccgo,!appengine + +#include "textflag.h" + +// Implementation of Poly1305 using the vector facility (vx). + +// constants +#define MOD26 V0 +#define EX0 V1 +#define EX1 V2 +#define EX2 V3 + +// temporaries +#define T_0 V4 +#define T_1 V5 +#define T_2 V6 +#define T_3 V7 +#define T_4 V8 + +// key (r) +#define R_0 V9 +#define R_1 V10 +#define R_2 V11 +#define R_3 V12 +#define R_4 V13 +#define R5_1 V14 +#define R5_2 V15 +#define R5_3 V16 +#define R5_4 V17 +#define RSAVE_0 R5 +#define RSAVE_1 R6 +#define RSAVE_2 R7 +#define RSAVE_3 R8 +#define RSAVE_4 R9 +#define R5SAVE_1 V28 +#define R5SAVE_2 V29 +#define R5SAVE_3 V30 +#define R5SAVE_4 V31 + +// message block +#define F_0 V18 +#define F_1 V19 +#define F_2 V20 +#define F_3 V21 +#define F_4 V22 + +// accumulator +#define H_0 V23 +#define H_1 V24 +#define H_2 V25 +#define H_3 V26 +#define H_4 V27 + +GLOBL ·keyMask<>(SB), RODATA, $16 +DATA ·keyMask<>+0(SB)/8, $0xffffff0ffcffff0f +DATA ·keyMask<>+8(SB)/8, $0xfcffff0ffcffff0f + +GLOBL ·bswapMask<>(SB), RODATA, $16 +DATA ·bswapMask<>+0(SB)/8, $0x0f0e0d0c0b0a0908 +DATA ·bswapMask<>+8(SB)/8, $0x0706050403020100 + +GLOBL ·constants<>(SB), RODATA, $64 +// MOD26 +DATA ·constants<>+0(SB)/8, $0x3ffffff +DATA ·constants<>+8(SB)/8, $0x3ffffff +// EX0 +DATA ·constants<>+16(SB)/8, $0x0006050403020100 +DATA ·constants<>+24(SB)/8, $0x1016151413121110 +// EX1 +DATA ·constants<>+32(SB)/8, $0x060c0b0a09080706 +DATA ·constants<>+40(SB)/8, $0x161c1b1a19181716 +// EX2 +DATA ·constants<>+48(SB)/8, $0x0d0d0d0d0d0f0e0d +DATA ·constants<>+56(SB)/8, $0x1d1d1d1d1d1f1e1d + +// h = (f*g) % (2**130-5) [partial reduction] +#define MULTIPLY(f0, f1, f2, f3, f4, g0, g1, g2, g3, g4, g51, g52, g53, g54, h0, h1, h2, h3, h4) \ + VMLOF f0, g0, h0 \ + VMLOF f0, g1, h1 \ + VMLOF f0, g2, h2 \ + VMLOF f0, g3, h3 \ + VMLOF f0, g4, h4 \ + VMLOF f1, g54, T_0 \ + VMLOF f1, g0, T_1 \ + VMLOF f1, g1, T_2 \ + VMLOF f1, g2, T_3 \ + VMLOF f1, g3, T_4 \ + VMALOF f2, g53, h0, h0 \ + VMALOF f2, g54, h1, h1 \ + VMALOF f2, g0, h2, h2 \ + VMALOF f2, g1, h3, h3 \ + VMALOF f2, g2, h4, h4 \ + VMALOF f3, g52, T_0, T_0 \ + VMALOF f3, g53, T_1, T_1 \ + VMALOF f3, g54, T_2, T_2 \ + VMALOF f3, g0, T_3, T_3 \ + VMALOF f3, g1, T_4, T_4 \ + VMALOF f4, g51, h0, h0 \ + VMALOF f4, g52, h1, h1 \ + VMALOF f4, g53, h2, h2 \ + VMALOF f4, g54, h3, h3 \ + VMALOF f4, g0, h4, h4 \ + VAG T_0, h0, h0 \ + VAG T_1, h1, h1 \ + VAG T_2, h2, h2 \ + VAG T_3, h3, h3 \ + VAG T_4, h4, h4 + +// carry h0->h1 h3->h4, h1->h2 h4->h0, h0->h1 h2->h3, h3->h4 +#define REDUCE(h0, h1, h2, h3, h4) \ + VESRLG $26, h0, T_0 \ + VESRLG $26, h3, T_1 \ + VN MOD26, h0, h0 \ + VN MOD26, h3, h3 \ + VAG T_0, h1, h1 \ + VAG T_1, h4, h4 \ + VESRLG $26, h1, T_2 \ + VESRLG $26, h4, T_3 \ + VN MOD26, h1, h1 \ + VN MOD26, h4, h4 \ + VESLG $2, T_3, T_4 \ + VAG T_3, T_4, T_4 \ + VAG T_2, h2, h2 \ + VAG T_4, h0, h0 \ + VESRLG $26, h2, T_0 \ + VESRLG $26, h0, T_1 \ + VN MOD26, h2, h2 \ + VN MOD26, h0, h0 \ + VAG T_0, h3, h3 \ + VAG T_1, h1, h1 \ + VESRLG $26, h3, T_2 \ + VN MOD26, h3, h3 \ + VAG T_2, h4, h4 + +// expand in0 into d[0] and in1 into d[1] +#define EXPAND(in0, in1, d0, d1, d2, d3, d4) \ + VGBM $0x0707, d1 \ // d1=tmp + VPERM in0, in1, EX2, d4 \ + VPERM in0, in1, EX0, d0 \ + VPERM in0, in1, EX1, d2 \ + VN d1, d4, d4 \ + VESRLG $26, d0, d1 \ + VESRLG $30, d2, d3 \ + VESRLG $4, d2, d2 \ + VN MOD26, d0, d0 \ + VN MOD26, d1, d1 \ + VN MOD26, d2, d2 \ + VN MOD26, d3, d3 + +// pack h4:h0 into h1:h0 (no carry) +#define PACK(h0, h1, h2, h3, h4) \ + VESLG $26, h1, h1 \ + VESLG $26, h3, h3 \ + VO h0, h1, h0 \ + VO h2, h3, h2 \ + VESLG $4, h2, h2 \ + VLEIB $7, $48, h1 \ + VSLB h1, h2, h2 \ + VO h0, h2, h0 \ + VLEIB $7, $104, h1 \ + VSLB h1, h4, h3 \ + VO h3, h0, h0 \ + VLEIB $7, $24, h1 \ + VSRLB h1, h4, h1 + +// if h > 2**130-5 then h -= 2**130-5 +#define MOD(h0, h1, t0, t1, t2) \ + VZERO t0 \ + VLEIG $1, $5, t0 \ + VACCQ h0, t0, t1 \ + VAQ h0, t0, t0 \ + VONE t2 \ + VLEIG $1, $-4, t2 \ + VAQ t2, t1, t1 \ + VACCQ h1, t1, t1 \ + VONE t2 \ + VAQ t2, t1, t1 \ + VN h0, t1, t2 \ + VNC t0, t1, t1 \ + VO t1, t2, h0 + +// func poly1305vx(out *[16]byte, m *byte, mlen uint64, key *[32]key) +TEXT ·poly1305vx(SB), $0-32 + // This code processes up to 2 blocks (32 bytes) per iteration + // using the algorithm described in: + // NEON crypto, Daniel J. Bernstein & Peter Schwabe + // https://cryptojedi.org/papers/neoncrypto-20120320.pdf + LMG out+0(FP), R1, R4 // R1=out, R2=m, R3=mlen, R4=key + + // load MOD26, EX0, EX1 and EX2 + MOVD $·constants<>(SB), R5 + VLM (R5), MOD26, EX2 + + // setup r + VL (R4), T_0 + MOVD $·keyMask<>(SB), R6 + VL (R6), T_1 + VN T_0, T_1, T_0 + EXPAND(T_0, T_0, R_0, R_1, R_2, R_3, R_4) + + // setup r*5 + VLEIG $0, $5, T_0 + VLEIG $1, $5, T_0 + + // store r (for final block) + VMLOF T_0, R_1, R5SAVE_1 + VMLOF T_0, R_2, R5SAVE_2 + VMLOF T_0, R_3, R5SAVE_3 + VMLOF T_0, R_4, R5SAVE_4 + VLGVG $0, R_0, RSAVE_0 + VLGVG $0, R_1, RSAVE_1 + VLGVG $0, R_2, RSAVE_2 + VLGVG $0, R_3, RSAVE_3 + VLGVG $0, R_4, RSAVE_4 + + // skip r**2 calculation + CMPBLE R3, $16, skip + + // calculate r**2 + MULTIPLY(R_0, R_1, R_2, R_3, R_4, R_0, R_1, R_2, R_3, R_4, R5SAVE_1, R5SAVE_2, R5SAVE_3, R5SAVE_4, H_0, H_1, H_2, H_3, H_4) + REDUCE(H_0, H_1, H_2, H_3, H_4) + VLEIG $0, $5, T_0 + VLEIG $1, $5, T_0 + VMLOF T_0, H_1, R5_1 + VMLOF T_0, H_2, R5_2 + VMLOF T_0, H_3, R5_3 + VMLOF T_0, H_4, R5_4 + VLR H_0, R_0 + VLR H_1, R_1 + VLR H_2, R_2 + VLR H_3, R_3 + VLR H_4, R_4 + + // initialize h + VZERO H_0 + VZERO H_1 + VZERO H_2 + VZERO H_3 + VZERO H_4 + +loop: + CMPBLE R3, $32, b2 + VLM (R2), T_0, T_1 + SUB $32, R3 + MOVD $32(R2), R2 + EXPAND(T_0, T_1, F_0, F_1, F_2, F_3, F_4) + VLEIB $4, $1, F_4 + VLEIB $12, $1, F_4 + +multiply: + VAG H_0, F_0, F_0 + VAG H_1, F_1, F_1 + VAG H_2, F_2, F_2 + VAG H_3, F_3, F_3 + VAG H_4, F_4, F_4 + MULTIPLY(F_0, F_1, F_2, F_3, F_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, H_0, H_1, H_2, H_3, H_4) + REDUCE(H_0, H_1, H_2, H_3, H_4) + CMPBNE R3, $0, loop + +finish: + // sum vectors + VZERO T_0 + VSUMQG H_0, T_0, H_0 + VSUMQG H_1, T_0, H_1 + VSUMQG H_2, T_0, H_2 + VSUMQG H_3, T_0, H_3 + VSUMQG H_4, T_0, H_4 + + // h may be >= 2*(2**130-5) so we need to reduce it again + REDUCE(H_0, H_1, H_2, H_3, H_4) + + // carry h1->h4 + VESRLG $26, H_1, T_1 + VN MOD26, H_1, H_1 + VAQ T_1, H_2, H_2 + VESRLG $26, H_2, T_2 + VN MOD26, H_2, H_2 + VAQ T_2, H_3, H_3 + VESRLG $26, H_3, T_3 + VN MOD26, H_3, H_3 + VAQ T_3, H_4, H_4 + + // h is now < 2*(2**130-5) + // pack h into h1 (hi) and h0 (lo) + PACK(H_0, H_1, H_2, H_3, H_4) + + // if h > 2**130-5 then h -= 2**130-5 + MOD(H_0, H_1, T_0, T_1, T_2) + + // h += s + MOVD $·bswapMask<>(SB), R5 + VL (R5), T_1 + VL 16(R4), T_0 + VPERM T_0, T_0, T_1, T_0 // reverse bytes (to big) + VAQ T_0, H_0, H_0 + VPERM H_0, H_0, T_1, H_0 // reverse bytes (to little) + VST H_0, (R1) + + RET + +b2: + CMPBLE R3, $16, b1 + + // 2 blocks remaining + SUB $17, R3 + VL (R2), T_0 + VLL R3, 16(R2), T_1 + ADD $1, R3 + MOVBZ $1, R0 + CMPBEQ R3, $16, 2(PC) + VLVGB R3, R0, T_1 + EXPAND(T_0, T_1, F_0, F_1, F_2, F_3, F_4) + CMPBNE R3, $16, 2(PC) + VLEIB $12, $1, F_4 + VLEIB $4, $1, F_4 + + // setup [r²,r] + VLVGG $1, RSAVE_0, R_0 + VLVGG $1, RSAVE_1, R_1 + VLVGG $1, RSAVE_2, R_2 + VLVGG $1, RSAVE_3, R_3 + VLVGG $1, RSAVE_4, R_4 + VPDI $0, R5_1, R5SAVE_1, R5_1 + VPDI $0, R5_2, R5SAVE_2, R5_2 + VPDI $0, R5_3, R5SAVE_3, R5_3 + VPDI $0, R5_4, R5SAVE_4, R5_4 + + MOVD $0, R3 + BR multiply + +skip: + VZERO H_0 + VZERO H_1 + VZERO H_2 + VZERO H_3 + VZERO H_4 + + CMPBEQ R3, $0, finish + +b1: + // 1 block remaining + SUB $1, R3 + VLL R3, (R2), T_0 + ADD $1, R3 + MOVBZ $1, R0 + CMPBEQ R3, $16, 2(PC) + VLVGB R3, R0, T_0 + VZERO T_1 + EXPAND(T_0, T_1, F_0, F_1, F_2, F_3, F_4) + CMPBNE R3, $16, 2(PC) + VLEIB $4, $1, F_4 + VLEIG $1, $1, R_0 + VZERO R_1 + VZERO R_2 + VZERO R_3 + VZERO R_4 + VZERO R5_1 + VZERO R5_2 + VZERO R5_3 + VZERO R5_4 + + // setup [r, 1] + VLVGG $0, RSAVE_0, R_0 + VLVGG $0, RSAVE_1, R_1 + VLVGG $0, RSAVE_2, R_2 + VLVGG $0, RSAVE_3, R_3 + VLVGG $0, RSAVE_4, R_4 + VPDI $0, R5SAVE_1, R5_1, R5_1 + VPDI $0, R5SAVE_2, R5_2, R5_2 + VPDI $0, R5SAVE_3, R5_3, R5_3 + VPDI $0, R5SAVE_4, R5_4, R5_4 + + MOVD $0, R3 + BR multiply + +TEXT ·hasVectorFacility(SB), NOSPLIT, $24-1 + MOVD $x-24(SP), R1 + XC $24, 0(R1), 0(R1) // clear the storage + MOVD $2, R0 // R0 is the number of double words stored -1 + WORD $0xB2B01000 // STFLE 0(R1) + XOR R0, R0 // reset the value of R0 + MOVBZ z-8(SP), R1 + AND $0x40, R1 + BEQ novector + +vectorinstalled: + // check if the vector instruction has been enabled + VLEIB $0, $0xF, V16 + VLGVB $0, V16, R1 + CMPBNE R1, $0xF, novector + MOVB $1, ret+0(FP) // have vx + RET + +novector: + MOVB $0, ret+0(FP) // no vx + RET diff --git a/vendor/golang.org/x/crypto/poly1305/sum_vmsl_s390x.s b/vendor/golang.org/x/crypto/poly1305/sum_vmsl_s390x.s new file mode 100644 index 00000000..e548020b --- /dev/null +++ b/vendor/golang.org/x/crypto/poly1305/sum_vmsl_s390x.s @@ -0,0 +1,931 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build s390x,go1.11,!gccgo,!appengine + +#include "textflag.h" + +// Implementation of Poly1305 using the vector facility (vx) and the VMSL instruction. + +// constants +#define EX0 V1 +#define EX1 V2 +#define EX2 V3 + +// temporaries +#define T_0 V4 +#define T_1 V5 +#define T_2 V6 +#define T_3 V7 +#define T_4 V8 +#define T_5 V9 +#define T_6 V10 +#define T_7 V11 +#define T_8 V12 +#define T_9 V13 +#define T_10 V14 + +// r**2 & r**4 +#define R_0 V15 +#define R_1 V16 +#define R_2 V17 +#define R5_1 V18 +#define R5_2 V19 +// key (r) +#define RSAVE_0 R7 +#define RSAVE_1 R8 +#define RSAVE_2 R9 +#define R5SAVE_1 R10 +#define R5SAVE_2 R11 + +// message block +#define M0 V20 +#define M1 V21 +#define M2 V22 +#define M3 V23 +#define M4 V24 +#define M5 V25 + +// accumulator +#define H0_0 V26 +#define H1_0 V27 +#define H2_0 V28 +#define H0_1 V29 +#define H1_1 V30 +#define H2_1 V31 + +GLOBL ·keyMask<>(SB), RODATA, $16 +DATA ·keyMask<>+0(SB)/8, $0xffffff0ffcffff0f +DATA ·keyMask<>+8(SB)/8, $0xfcffff0ffcffff0f + +GLOBL ·bswapMask<>(SB), RODATA, $16 +DATA ·bswapMask<>+0(SB)/8, $0x0f0e0d0c0b0a0908 +DATA ·bswapMask<>+8(SB)/8, $0x0706050403020100 + +GLOBL ·constants<>(SB), RODATA, $48 +// EX0 +DATA ·constants<>+0(SB)/8, $0x18191a1b1c1d1e1f +DATA ·constants<>+8(SB)/8, $0x0000050403020100 +// EX1 +DATA ·constants<>+16(SB)/8, $0x18191a1b1c1d1e1f +DATA ·constants<>+24(SB)/8, $0x00000a0908070605 +// EX2 +DATA ·constants<>+32(SB)/8, $0x18191a1b1c1d1e1f +DATA ·constants<>+40(SB)/8, $0x0000000f0e0d0c0b + +GLOBL ·c<>(SB), RODATA, $48 +// EX0 +DATA ·c<>+0(SB)/8, $0x0000050403020100 +DATA ·c<>+8(SB)/8, $0x0000151413121110 +// EX1 +DATA ·c<>+16(SB)/8, $0x00000a0908070605 +DATA ·c<>+24(SB)/8, $0x00001a1918171615 +// EX2 +DATA ·c<>+32(SB)/8, $0x0000000f0e0d0c0b +DATA ·c<>+40(SB)/8, $0x0000001f1e1d1c1b + +GLOBL ·reduce<>(SB), RODATA, $32 +// 44 bit +DATA ·reduce<>+0(SB)/8, $0x0 +DATA ·reduce<>+8(SB)/8, $0xfffffffffff +// 42 bit +DATA ·reduce<>+16(SB)/8, $0x0 +DATA ·reduce<>+24(SB)/8, $0x3ffffffffff + +// h = (f*g) % (2**130-5) [partial reduction] +// uses T_0...T_9 temporary registers +// input: m02_0, m02_1, m02_2, m13_0, m13_1, m13_2, r_0, r_1, r_2, r5_1, r5_2, m4_0, m4_1, m4_2, m5_0, m5_1, m5_2 +// temp: t0, t1, t2, t3, t4, t5, t6, t7, t8, t9 +// output: m02_0, m02_1, m02_2, m13_0, m13_1, m13_2 +#define MULTIPLY(m02_0, m02_1, m02_2, m13_0, m13_1, m13_2, r_0, r_1, r_2, r5_1, r5_2, m4_0, m4_1, m4_2, m5_0, m5_1, m5_2, t0, t1, t2, t3, t4, t5, t6, t7, t8, t9) \ + \ // Eliminate the dependency for the last 2 VMSLs + VMSLG m02_0, r_2, m4_2, m4_2 \ + VMSLG m13_0, r_2, m5_2, m5_2 \ // 8 VMSLs pipelined + VMSLG m02_0, r_0, m4_0, m4_0 \ + VMSLG m02_1, r5_2, V0, T_0 \ + VMSLG m02_0, r_1, m4_1, m4_1 \ + VMSLG m02_1, r_0, V0, T_1 \ + VMSLG m02_1, r_1, V0, T_2 \ + VMSLG m02_2, r5_1, V0, T_3 \ + VMSLG m02_2, r5_2, V0, T_4 \ + VMSLG m13_0, r_0, m5_0, m5_0 \ + VMSLG m13_1, r5_2, V0, T_5 \ + VMSLG m13_0, r_1, m5_1, m5_1 \ + VMSLG m13_1, r_0, V0, T_6 \ + VMSLG m13_1, r_1, V0, T_7 \ + VMSLG m13_2, r5_1, V0, T_8 \ + VMSLG m13_2, r5_2, V0, T_9 \ + VMSLG m02_2, r_0, m4_2, m4_2 \ + VMSLG m13_2, r_0, m5_2, m5_2 \ + VAQ m4_0, T_0, m02_0 \ + VAQ m4_1, T_1, m02_1 \ + VAQ m5_0, T_5, m13_0 \ + VAQ m5_1, T_6, m13_1 \ + VAQ m02_0, T_3, m02_0 \ + VAQ m02_1, T_4, m02_1 \ + VAQ m13_0, T_8, m13_0 \ + VAQ m13_1, T_9, m13_1 \ + VAQ m4_2, T_2, m02_2 \ + VAQ m5_2, T_7, m13_2 \ + +// SQUARE uses three limbs of r and r_2*5 to output square of r +// uses T_1, T_5 and T_7 temporary registers +// input: r_0, r_1, r_2, r5_2 +// temp: TEMP0, TEMP1, TEMP2 +// output: p0, p1, p2 +#define SQUARE(r_0, r_1, r_2, r5_2, p0, p1, p2, TEMP0, TEMP1, TEMP2) \ + VMSLG r_0, r_0, p0, p0 \ + VMSLG r_1, r5_2, V0, TEMP0 \ + VMSLG r_2, r5_2, p1, p1 \ + VMSLG r_0, r_1, V0, TEMP1 \ + VMSLG r_1, r_1, p2, p2 \ + VMSLG r_0, r_2, V0, TEMP2 \ + VAQ TEMP0, p0, p0 \ + VAQ TEMP1, p1, p1 \ + VAQ TEMP2, p2, p2 \ + VAQ TEMP0, p0, p0 \ + VAQ TEMP1, p1, p1 \ + VAQ TEMP2, p2, p2 \ + +// carry h0->h1->h2->h0 || h3->h4->h5->h3 +// uses T_2, T_4, T_5, T_7, T_8, T_9 +// t6, t7, t8, t9, t10, t11 +// input: h0, h1, h2, h3, h4, h5 +// temp: t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11 +// output: h0, h1, h2, h3, h4, h5 +#define REDUCE(h0, h1, h2, h3, h4, h5, t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \ + VLM (R12), t6, t7 \ // 44 and 42 bit clear mask + VLEIB $7, $0x28, t10 \ // 5 byte shift mask + VREPIB $4, t8 \ // 4 bit shift mask + VREPIB $2, t11 \ // 2 bit shift mask + VSRLB t10, h0, t0 \ // h0 byte shift + VSRLB t10, h1, t1 \ // h1 byte shift + VSRLB t10, h2, t2 \ // h2 byte shift + VSRLB t10, h3, t3 \ // h3 byte shift + VSRLB t10, h4, t4 \ // h4 byte shift + VSRLB t10, h5, t5 \ // h5 byte shift + VSRL t8, t0, t0 \ // h0 bit shift + VSRL t8, t1, t1 \ // h2 bit shift + VSRL t11, t2, t2 \ // h2 bit shift + VSRL t8, t3, t3 \ // h3 bit shift + VSRL t8, t4, t4 \ // h4 bit shift + VESLG $2, t2, t9 \ // h2 carry x5 + VSRL t11, t5, t5 \ // h5 bit shift + VN t6, h0, h0 \ // h0 clear carry + VAQ t2, t9, t2 \ // h2 carry x5 + VESLG $2, t5, t9 \ // h5 carry x5 + VN t6, h1, h1 \ // h1 clear carry + VN t7, h2, h2 \ // h2 clear carry + VAQ t5, t9, t5 \ // h5 carry x5 + VN t6, h3, h3 \ // h3 clear carry + VN t6, h4, h4 \ // h4 clear carry + VN t7, h5, h5 \ // h5 clear carry + VAQ t0, h1, h1 \ // h0->h1 + VAQ t3, h4, h4 \ // h3->h4 + VAQ t1, h2, h2 \ // h1->h2 + VAQ t4, h5, h5 \ // h4->h5 + VAQ t2, h0, h0 \ // h2->h0 + VAQ t5, h3, h3 \ // h5->h3 + VREPG $1, t6, t6 \ // 44 and 42 bit masks across both halves + VREPG $1, t7, t7 \ + VSLDB $8, h0, h0, h0 \ // set up [h0/1/2, h3/4/5] + VSLDB $8, h1, h1, h1 \ + VSLDB $8, h2, h2, h2 \ + VO h0, h3, h3 \ + VO h1, h4, h4 \ + VO h2, h5, h5 \ + VESRLG $44, h3, t0 \ // 44 bit shift right + VESRLG $44, h4, t1 \ + VESRLG $42, h5, t2 \ + VN t6, h3, h3 \ // clear carry bits + VN t6, h4, h4 \ + VN t7, h5, h5 \ + VESLG $2, t2, t9 \ // multiply carry by 5 + VAQ t9, t2, t2 \ + VAQ t0, h4, h4 \ + VAQ t1, h5, h5 \ + VAQ t2, h3, h3 \ + +// carry h0->h1->h2->h0 +// input: h0, h1, h2 +// temp: t0, t1, t2, t3, t4, t5, t6, t7, t8 +// output: h0, h1, h2 +#define REDUCE2(h0, h1, h2, t0, t1, t2, t3, t4, t5, t6, t7, t8) \ + VLEIB $7, $0x28, t3 \ // 5 byte shift mask + VREPIB $4, t4 \ // 4 bit shift mask + VREPIB $2, t7 \ // 2 bit shift mask + VGBM $0x003F, t5 \ // mask to clear carry bits + VSRLB t3, h0, t0 \ + VSRLB t3, h1, t1 \ + VSRLB t3, h2, t2 \ + VESRLG $4, t5, t5 \ // 44 bit clear mask + VSRL t4, t0, t0 \ + VSRL t4, t1, t1 \ + VSRL t7, t2, t2 \ + VESRLG $2, t5, t6 \ // 42 bit clear mask + VESLG $2, t2, t8 \ + VAQ t8, t2, t2 \ + VN t5, h0, h0 \ + VN t5, h1, h1 \ + VN t6, h2, h2 \ + VAQ t0, h1, h1 \ + VAQ t1, h2, h2 \ + VAQ t2, h0, h0 \ + VSRLB t3, h0, t0 \ + VSRLB t3, h1, t1 \ + VSRLB t3, h2, t2 \ + VSRL t4, t0, t0 \ + VSRL t4, t1, t1 \ + VSRL t7, t2, t2 \ + VN t5, h0, h0 \ + VN t5, h1, h1 \ + VESLG $2, t2, t8 \ + VN t6, h2, h2 \ + VAQ t0, h1, h1 \ + VAQ t8, t2, t2 \ + VAQ t1, h2, h2 \ + VAQ t2, h0, h0 \ + +// expands two message blocks into the lower halfs of the d registers +// moves the contents of the d registers into upper halfs +// input: in1, in2, d0, d1, d2, d3, d4, d5 +// temp: TEMP0, TEMP1, TEMP2, TEMP3 +// output: d0, d1, d2, d3, d4, d5 +#define EXPACC(in1, in2, d0, d1, d2, d3, d4, d5, TEMP0, TEMP1, TEMP2, TEMP3) \ + VGBM $0xff3f, TEMP0 \ + VGBM $0xff1f, TEMP1 \ + VESLG $4, d1, TEMP2 \ + VESLG $4, d4, TEMP3 \ + VESRLG $4, TEMP0, TEMP0 \ + VPERM in1, d0, EX0, d0 \ + VPERM in2, d3, EX0, d3 \ + VPERM in1, d2, EX2, d2 \ + VPERM in2, d5, EX2, d5 \ + VPERM in1, TEMP2, EX1, d1 \ + VPERM in2, TEMP3, EX1, d4 \ + VN TEMP0, d0, d0 \ + VN TEMP0, d3, d3 \ + VESRLG $4, d1, d1 \ + VESRLG $4, d4, d4 \ + VN TEMP1, d2, d2 \ + VN TEMP1, d5, d5 \ + VN TEMP0, d1, d1 \ + VN TEMP0, d4, d4 \ + +// expands one message block into the lower halfs of the d registers +// moves the contents of the d registers into upper halfs +// input: in, d0, d1, d2 +// temp: TEMP0, TEMP1, TEMP2 +// output: d0, d1, d2 +#define EXPACC2(in, d0, d1, d2, TEMP0, TEMP1, TEMP2) \ + VGBM $0xff3f, TEMP0 \ + VESLG $4, d1, TEMP2 \ + VGBM $0xff1f, TEMP1 \ + VPERM in, d0, EX0, d0 \ + VESRLG $4, TEMP0, TEMP0 \ + VPERM in, d2, EX2, d2 \ + VPERM in, TEMP2, EX1, d1 \ + VN TEMP0, d0, d0 \ + VN TEMP1, d2, d2 \ + VESRLG $4, d1, d1 \ + VN TEMP0, d1, d1 \ + +// pack h2:h0 into h1:h0 (no carry) +// input: h0, h1, h2 +// output: h0, h1, h2 +#define PACK(h0, h1, h2) \ + VMRLG h1, h2, h2 \ // copy h1 to upper half h2 + VESLG $44, h1, h1 \ // shift limb 1 44 bits, leaving 20 + VO h0, h1, h0 \ // combine h0 with 20 bits from limb 1 + VESRLG $20, h2, h1 \ // put top 24 bits of limb 1 into h1 + VLEIG $1, $0, h1 \ // clear h2 stuff from lower half of h1 + VO h0, h1, h0 \ // h0 now has 88 bits (limb 0 and 1) + VLEIG $0, $0, h2 \ // clear upper half of h2 + VESRLG $40, h2, h1 \ // h1 now has upper two bits of result + VLEIB $7, $88, h1 \ // for byte shift (11 bytes) + VSLB h1, h2, h2 \ // shift h2 11 bytes to the left + VO h0, h2, h0 \ // combine h0 with 20 bits from limb 1 + VLEIG $0, $0, h1 \ // clear upper half of h1 + +// if h > 2**130-5 then h -= 2**130-5 +// input: h0, h1 +// temp: t0, t1, t2 +// output: h0 +#define MOD(h0, h1, t0, t1, t2) \ + VZERO t0 \ + VLEIG $1, $5, t0 \ + VACCQ h0, t0, t1 \ + VAQ h0, t0, t0 \ + VONE t2 \ + VLEIG $1, $-4, t2 \ + VAQ t2, t1, t1 \ + VACCQ h1, t1, t1 \ + VONE t2 \ + VAQ t2, t1, t1 \ + VN h0, t1, t2 \ + VNC t0, t1, t1 \ + VO t1, t2, h0 \ + +// func poly1305vmsl(out *[16]byte, m *byte, mlen uint64, key *[32]key) +TEXT ·poly1305vmsl(SB), $0-32 + // This code processes 6 + up to 4 blocks (32 bytes) per iteration + // using the algorithm described in: + // NEON crypto, Daniel J. Bernstein & Peter Schwabe + // https://cryptojedi.org/papers/neoncrypto-20120320.pdf + // And as moddified for VMSL as described in + // Accelerating Poly1305 Cryptographic Message Authentication on the z14 + // O'Farrell et al, CASCON 2017, p48-55 + // https://ibm.ent.box.com/s/jf9gedj0e9d2vjctfyh186shaztavnht + + LMG out+0(FP), R1, R4 // R1=out, R2=m, R3=mlen, R4=key + VZERO V0 // c + + // load EX0, EX1 and EX2 + MOVD $·constants<>(SB), R5 + VLM (R5), EX0, EX2 // c + + // setup r + VL (R4), T_0 + MOVD $·keyMask<>(SB), R6 + VL (R6), T_1 + VN T_0, T_1, T_0 + VZERO T_2 // limbs for r + VZERO T_3 + VZERO T_4 + EXPACC2(T_0, T_2, T_3, T_4, T_1, T_5, T_7) + + // T_2, T_3, T_4: [0, r] + + // setup r*20 + VLEIG $0, $0, T_0 + VLEIG $1, $20, T_0 // T_0: [0, 20] + VZERO T_5 + VZERO T_6 + VMSLG T_0, T_3, T_5, T_5 + VMSLG T_0, T_4, T_6, T_6 + + // store r for final block in GR + VLGVG $1, T_2, RSAVE_0 // c + VLGVG $1, T_3, RSAVE_1 // c + VLGVG $1, T_4, RSAVE_2 // c + VLGVG $1, T_5, R5SAVE_1 // c + VLGVG $1, T_6, R5SAVE_2 // c + + // initialize h + VZERO H0_0 + VZERO H1_0 + VZERO H2_0 + VZERO H0_1 + VZERO H1_1 + VZERO H2_1 + + // initialize pointer for reduce constants + MOVD $·reduce<>(SB), R12 + + // calculate r**2 and 20*(r**2) + VZERO R_0 + VZERO R_1 + VZERO R_2 + SQUARE(T_2, T_3, T_4, T_6, R_0, R_1, R_2, T_1, T_5, T_7) + REDUCE2(R_0, R_1, R_2, M0, M1, M2, M3, M4, R5_1, R5_2, M5, T_1) + VZERO R5_1 + VZERO R5_2 + VMSLG T_0, R_1, R5_1, R5_1 + VMSLG T_0, R_2, R5_2, R5_2 + + // skip r**4 calculation if 3 blocks or less + CMPBLE R3, $48, b4 + + // calculate r**4 and 20*(r**4) + VZERO T_8 + VZERO T_9 + VZERO T_10 + SQUARE(R_0, R_1, R_2, R5_2, T_8, T_9, T_10, T_1, T_5, T_7) + REDUCE2(T_8, T_9, T_10, M0, M1, M2, M3, M4, T_2, T_3, M5, T_1) + VZERO T_2 + VZERO T_3 + VMSLG T_0, T_9, T_2, T_2 + VMSLG T_0, T_10, T_3, T_3 + + // put r**2 to the right and r**4 to the left of R_0, R_1, R_2 + VSLDB $8, T_8, T_8, T_8 + VSLDB $8, T_9, T_9, T_9 + VSLDB $8, T_10, T_10, T_10 + VSLDB $8, T_2, T_2, T_2 + VSLDB $8, T_3, T_3, T_3 + + VO T_8, R_0, R_0 + VO T_9, R_1, R_1 + VO T_10, R_2, R_2 + VO T_2, R5_1, R5_1 + VO T_3, R5_2, R5_2 + + CMPBLE R3, $80, load // less than or equal to 5 blocks in message + + // 6(or 5+1) blocks + SUB $81, R3 + VLM (R2), M0, M4 + VLL R3, 80(R2), M5 + ADD $1, R3 + MOVBZ $1, R0 + CMPBGE R3, $16, 2(PC) + VLVGB R3, R0, M5 + MOVD $96(R2), R2 + EXPACC(M0, M1, H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_0, T_1, T_2, T_3) + EXPACC(M2, M3, H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_0, T_1, T_2, T_3) + VLEIB $2, $1, H2_0 + VLEIB $2, $1, H2_1 + VLEIB $10, $1, H2_0 + VLEIB $10, $1, H2_1 + + VZERO M0 + VZERO M1 + VZERO M2 + VZERO M3 + VZERO T_4 + VZERO T_10 + EXPACC(M4, M5, M0, M1, M2, M3, T_4, T_10, T_0, T_1, T_2, T_3) + VLR T_4, M4 + VLEIB $10, $1, M2 + CMPBLT R3, $16, 2(PC) + VLEIB $10, $1, T_10 + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, T_10, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M2, M3, M4, T_4, T_5, T_2, T_7, T_8, T_9) + VMRHG V0, H0_1, H0_0 + VMRHG V0, H1_1, H1_0 + VMRHG V0, H2_1, H2_0 + VMRLG V0, H0_1, H0_1 + VMRLG V0, H1_1, H1_1 + VMRLG V0, H2_1, H2_1 + + SUB $16, R3 + CMPBLE R3, $0, square + +load: + // load EX0, EX1 and EX2 + MOVD $·c<>(SB), R5 + VLM (R5), EX0, EX2 + +loop: + CMPBLE R3, $64, add // b4 // last 4 or less blocks left + + // next 4 full blocks + VLM (R2), M2, M5 + SUB $64, R3 + MOVD $64(R2), R2 + REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, T_0, T_1, T_3, T_4, T_5, T_2, T_7, T_8, T_9) + + // expacc in-lined to create [m2, m3] limbs + VGBM $0x3f3f, T_0 // 44 bit clear mask + VGBM $0x1f1f, T_1 // 40 bit clear mask + VPERM M2, M3, EX0, T_3 + VESRLG $4, T_0, T_0 // 44 bit clear mask ready + VPERM M2, M3, EX1, T_4 + VPERM M2, M3, EX2, T_5 + VN T_0, T_3, T_3 + VESRLG $4, T_4, T_4 + VN T_1, T_5, T_5 + VN T_0, T_4, T_4 + VMRHG H0_1, T_3, H0_0 + VMRHG H1_1, T_4, H1_0 + VMRHG H2_1, T_5, H2_0 + VMRLG H0_1, T_3, H0_1 + VMRLG H1_1, T_4, H1_1 + VMRLG H2_1, T_5, H2_1 + VLEIB $10, $1, H2_0 + VLEIB $10, $1, H2_1 + VPERM M4, M5, EX0, T_3 + VPERM M4, M5, EX1, T_4 + VPERM M4, M5, EX2, T_5 + VN T_0, T_3, T_3 + VESRLG $4, T_4, T_4 + VN T_1, T_5, T_5 + VN T_0, T_4, T_4 + VMRHG V0, T_3, M0 + VMRHG V0, T_4, M1 + VMRHG V0, T_5, M2 + VMRLG V0, T_3, M3 + VMRLG V0, T_4, M4 + VMRLG V0, T_5, M5 + VLEIB $10, $1, M2 + VLEIB $10, $1, M5 + + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + CMPBNE R3, $0, loop + REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M3, M4, M5, T_4, T_5, T_2, T_7, T_8, T_9) + VMRHG V0, H0_1, H0_0 + VMRHG V0, H1_1, H1_0 + VMRHG V0, H2_1, H2_0 + VMRLG V0, H0_1, H0_1 + VMRLG V0, H1_1, H1_1 + VMRLG V0, H2_1, H2_1 + + // load EX0, EX1, EX2 + MOVD $·constants<>(SB), R5 + VLM (R5), EX0, EX2 + + // sum vectors + VAQ H0_0, H0_1, H0_0 + VAQ H1_0, H1_1, H1_0 + VAQ H2_0, H2_1, H2_0 + + // h may be >= 2*(2**130-5) so we need to reduce it again + // M0...M4 are used as temps here + REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5) + +next: // carry h1->h2 + VLEIB $7, $0x28, T_1 + VREPIB $4, T_2 + VGBM $0x003F, T_3 + VESRLG $4, T_3 + + // byte shift + VSRLB T_1, H1_0, T_4 + + // bit shift + VSRL T_2, T_4, T_4 + + // clear h1 carry bits + VN T_3, H1_0, H1_0 + + // add carry + VAQ T_4, H2_0, H2_0 + + // h is now < 2*(2**130-5) + // pack h into h1 (hi) and h0 (lo) + PACK(H0_0, H1_0, H2_0) + + // if h > 2**130-5 then h -= 2**130-5 + MOD(H0_0, H1_0, T_0, T_1, T_2) + + // h += s + MOVD $·bswapMask<>(SB), R5 + VL (R5), T_1 + VL 16(R4), T_0 + VPERM T_0, T_0, T_1, T_0 // reverse bytes (to big) + VAQ T_0, H0_0, H0_0 + VPERM H0_0, H0_0, T_1, H0_0 // reverse bytes (to little) + VST H0_0, (R1) + RET + +add: + // load EX0, EX1, EX2 + MOVD $·constants<>(SB), R5 + VLM (R5), EX0, EX2 + + REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M3, M4, M5, T_4, T_5, T_2, T_7, T_8, T_9) + VMRHG V0, H0_1, H0_0 + VMRHG V0, H1_1, H1_0 + VMRHG V0, H2_1, H2_0 + VMRLG V0, H0_1, H0_1 + VMRLG V0, H1_1, H1_1 + VMRLG V0, H2_1, H2_1 + CMPBLE R3, $64, b4 + +b4: + CMPBLE R3, $48, b3 // 3 blocks or less + + // 4(3+1) blocks remaining + SUB $49, R3 + VLM (R2), M0, M2 + VLL R3, 48(R2), M3 + ADD $1, R3 + MOVBZ $1, R0 + CMPBEQ R3, $16, 2(PC) + VLVGB R3, R0, M3 + MOVD $64(R2), R2 + EXPACC(M0, M1, H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_0, T_1, T_2, T_3) + VLEIB $10, $1, H2_0 + VLEIB $10, $1, H2_1 + VZERO M0 + VZERO M1 + VZERO M4 + VZERO M5 + VZERO T_4 + VZERO T_10 + EXPACC(M2, M3, M0, M1, M4, M5, T_4, T_10, T_0, T_1, T_2, T_3) + VLR T_4, M2 + VLEIB $10, $1, M4 + CMPBNE R3, $16, 2(PC) + VLEIB $10, $1, T_10 + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M4, M5, M2, T_10, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M3, M4, M5, T_4, T_5, T_2, T_7, T_8, T_9) + VMRHG V0, H0_1, H0_0 + VMRHG V0, H1_1, H1_0 + VMRHG V0, H2_1, H2_0 + VMRLG V0, H0_1, H0_1 + VMRLG V0, H1_1, H1_1 + VMRLG V0, H2_1, H2_1 + SUB $16, R3 + CMPBLE R3, $0, square // this condition must always hold true! + +b3: + CMPBLE R3, $32, b2 + + // 3 blocks remaining + + // setup [r²,r] + VSLDB $8, R_0, R_0, R_0 + VSLDB $8, R_1, R_1, R_1 + VSLDB $8, R_2, R_2, R_2 + VSLDB $8, R5_1, R5_1, R5_1 + VSLDB $8, R5_2, R5_2, R5_2 + + VLVGG $1, RSAVE_0, R_0 + VLVGG $1, RSAVE_1, R_1 + VLVGG $1, RSAVE_2, R_2 + VLVGG $1, R5SAVE_1, R5_1 + VLVGG $1, R5SAVE_2, R5_2 + + // setup [h0, h1] + VSLDB $8, H0_0, H0_0, H0_0 + VSLDB $8, H1_0, H1_0, H1_0 + VSLDB $8, H2_0, H2_0, H2_0 + VO H0_1, H0_0, H0_0 + VO H1_1, H1_0, H1_0 + VO H2_1, H2_0, H2_0 + VZERO H0_1 + VZERO H1_1 + VZERO H2_1 + + VZERO M0 + VZERO M1 + VZERO M2 + VZERO M3 + VZERO M4 + VZERO M5 + + // H*[r**2, r] + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, H0_1, H1_1, T_10, M5) + + SUB $33, R3 + VLM (R2), M0, M1 + VLL R3, 32(R2), M2 + ADD $1, R3 + MOVBZ $1, R0 + CMPBEQ R3, $16, 2(PC) + VLVGB R3, R0, M2 + + // H += m0 + VZERO T_1 + VZERO T_2 + VZERO T_3 + EXPACC2(M0, T_1, T_2, T_3, T_4, T_5, T_6) + VLEIB $10, $1, T_3 + VAG H0_0, T_1, H0_0 + VAG H1_0, T_2, H1_0 + VAG H2_0, T_3, H2_0 + + VZERO M0 + VZERO M3 + VZERO M4 + VZERO M5 + VZERO T_10 + + // (H+m0)*r + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M3, M4, M5, V0, T_10, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE2(H0_0, H1_0, H2_0, M0, M3, M4, M5, T_10, H0_1, H1_1, H2_1, T_9) + + // H += m1 + VZERO V0 + VZERO T_1 + VZERO T_2 + VZERO T_3 + EXPACC2(M1, T_1, T_2, T_3, T_4, T_5, T_6) + VLEIB $10, $1, T_3 + VAQ H0_0, T_1, H0_0 + VAQ H1_0, T_2, H1_0 + VAQ H2_0, T_3, H2_0 + REDUCE2(H0_0, H1_0, H2_0, M0, M3, M4, M5, T_9, H0_1, H1_1, H2_1, T_10) + + // [H, m2] * [r**2, r] + EXPACC2(M2, H0_0, H1_0, H2_0, T_1, T_2, T_3) + CMPBNE R3, $16, 2(PC) + VLEIB $10, $1, H2_0 + VZERO M0 + VZERO M1 + VZERO M2 + VZERO M3 + VZERO M4 + VZERO M5 + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, H0_1, H1_1, M5, T_10) + SUB $16, R3 + CMPBLE R3, $0, next // this condition must always hold true! + +b2: + CMPBLE R3, $16, b1 + + // 2 blocks remaining + + // setup [r²,r] + VSLDB $8, R_0, R_0, R_0 + VSLDB $8, R_1, R_1, R_1 + VSLDB $8, R_2, R_2, R_2 + VSLDB $8, R5_1, R5_1, R5_1 + VSLDB $8, R5_2, R5_2, R5_2 + + VLVGG $1, RSAVE_0, R_0 + VLVGG $1, RSAVE_1, R_1 + VLVGG $1, RSAVE_2, R_2 + VLVGG $1, R5SAVE_1, R5_1 + VLVGG $1, R5SAVE_2, R5_2 + + // setup [h0, h1] + VSLDB $8, H0_0, H0_0, H0_0 + VSLDB $8, H1_0, H1_0, H1_0 + VSLDB $8, H2_0, H2_0, H2_0 + VO H0_1, H0_0, H0_0 + VO H1_1, H1_0, H1_0 + VO H2_1, H2_0, H2_0 + VZERO H0_1 + VZERO H1_1 + VZERO H2_1 + + VZERO M0 + VZERO M1 + VZERO M2 + VZERO M3 + VZERO M4 + VZERO M5 + + // H*[r**2, r] + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M2, M3, M4, T_4, T_5, T_2, T_7, T_8, T_9) + VMRHG V0, H0_1, H0_0 + VMRHG V0, H1_1, H1_0 + VMRHG V0, H2_1, H2_0 + VMRLG V0, H0_1, H0_1 + VMRLG V0, H1_1, H1_1 + VMRLG V0, H2_1, H2_1 + + // move h to the left and 0s at the right + VSLDB $8, H0_0, H0_0, H0_0 + VSLDB $8, H1_0, H1_0, H1_0 + VSLDB $8, H2_0, H2_0, H2_0 + + // get message blocks and append 1 to start + SUB $17, R3 + VL (R2), M0 + VLL R3, 16(R2), M1 + ADD $1, R3 + MOVBZ $1, R0 + CMPBEQ R3, $16, 2(PC) + VLVGB R3, R0, M1 + VZERO T_6 + VZERO T_7 + VZERO T_8 + EXPACC2(M0, T_6, T_7, T_8, T_1, T_2, T_3) + EXPACC2(M1, T_6, T_7, T_8, T_1, T_2, T_3) + VLEIB $2, $1, T_8 + CMPBNE R3, $16, 2(PC) + VLEIB $10, $1, T_8 + + // add [m0, m1] to h + VAG H0_0, T_6, H0_0 + VAG H1_0, T_7, H1_0 + VAG H2_0, T_8, H2_0 + + VZERO M2 + VZERO M3 + VZERO M4 + VZERO M5 + VZERO T_10 + VZERO M0 + + // at this point R_0 .. R5_2 look like [r**2, r] + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M2, M3, M4, M5, T_10, M0, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE2(H0_0, H1_0, H2_0, M2, M3, M4, M5, T_9, H0_1, H1_1, H2_1, T_10) + SUB $16, R3, R3 + CMPBLE R3, $0, next + +b1: + CMPBLE R3, $0, next + + // 1 block remaining + + // setup [r²,r] + VSLDB $8, R_0, R_0, R_0 + VSLDB $8, R_1, R_1, R_1 + VSLDB $8, R_2, R_2, R_2 + VSLDB $8, R5_1, R5_1, R5_1 + VSLDB $8, R5_2, R5_2, R5_2 + + VLVGG $1, RSAVE_0, R_0 + VLVGG $1, RSAVE_1, R_1 + VLVGG $1, RSAVE_2, R_2 + VLVGG $1, R5SAVE_1, R5_1 + VLVGG $1, R5SAVE_2, R5_2 + + // setup [h0, h1] + VSLDB $8, H0_0, H0_0, H0_0 + VSLDB $8, H1_0, H1_0, H1_0 + VSLDB $8, H2_0, H2_0, H2_0 + VO H0_1, H0_0, H0_0 + VO H1_1, H1_0, H1_0 + VO H2_1, H2_0, H2_0 + VZERO H0_1 + VZERO H1_1 + VZERO H2_1 + + VZERO M0 + VZERO M1 + VZERO M2 + VZERO M3 + VZERO M4 + VZERO M5 + + // H*[r**2, r] + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5) + + // set up [0, m0] limbs + SUB $1, R3 + VLL R3, (R2), M0 + ADD $1, R3 + MOVBZ $1, R0 + CMPBEQ R3, $16, 2(PC) + VLVGB R3, R0, M0 + VZERO T_1 + VZERO T_2 + VZERO T_3 + EXPACC2(M0, T_1, T_2, T_3, T_4, T_5, T_6)// limbs: [0, m] + CMPBNE R3, $16, 2(PC) + VLEIB $10, $1, T_3 + + // h+m0 + VAQ H0_0, T_1, H0_0 + VAQ H1_0, T_2, H1_0 + VAQ H2_0, T_3, H2_0 + + VZERO M0 + VZERO M1 + VZERO M2 + VZERO M3 + VZERO M4 + VZERO M5 + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5) + + BR next + +square: + // setup [r²,r] + VSLDB $8, R_0, R_0, R_0 + VSLDB $8, R_1, R_1, R_1 + VSLDB $8, R_2, R_2, R_2 + VSLDB $8, R5_1, R5_1, R5_1 + VSLDB $8, R5_2, R5_2, R5_2 + + VLVGG $1, RSAVE_0, R_0 + VLVGG $1, RSAVE_1, R_1 + VLVGG $1, RSAVE_2, R_2 + VLVGG $1, R5SAVE_1, R5_1 + VLVGG $1, R5SAVE_2, R5_2 + + // setup [h0, h1] + VSLDB $8, H0_0, H0_0, H0_0 + VSLDB $8, H1_0, H1_0, H1_0 + VSLDB $8, H2_0, H2_0, H2_0 + VO H0_1, H0_0, H0_0 + VO H1_1, H1_0, H1_0 + VO H2_1, H2_0, H2_0 + VZERO H0_1 + VZERO H1_1 + VZERO H2_1 + + VZERO M0 + VZERO M1 + VZERO M2 + VZERO M3 + VZERO M4 + VZERO M5 + + // (h0*r**2) + (h1*r) + MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9) + REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5) + BR next + +TEXT ·hasVMSLFacility(SB), NOSPLIT, $24-1 + MOVD $x-24(SP), R1 + XC $24, 0(R1), 0(R1) // clear the storage + MOVD $2, R0 // R0 is the number of double words stored -1 + WORD $0xB2B01000 // STFLE 0(R1) + XOR R0, R0 // reset the value of R0 + MOVBZ z-8(SP), R1 + AND $0x01, R1 + BEQ novmsl + +vectorinstalled: + // check if the vector instruction has been enabled + VLEIB $0, $0xF, V16 + VLGVB $0, V16, R1 + CMPBNE R1, $0xF, novmsl + MOVB $1, ret+0(FP) // have vx + RET + +novmsl: + MOVB $0, ret+0(FP) // no vx + RET diff --git a/vendor/golang.org/x/sys/LICENSE b/vendor/golang.org/x/sys/LICENSE new file mode 100644 index 00000000..6a66aea5 --- /dev/null +++ b/vendor/golang.org/x/sys/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sys/PATENTS b/vendor/golang.org/x/sys/PATENTS new file mode 100644 index 00000000..73309904 --- /dev/null +++ b/vendor/golang.org/x/sys/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/sys/cpu/cpu.go b/vendor/golang.org/x/sys/cpu/cpu.go new file mode 100644 index 00000000..3d88f866 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu.go @@ -0,0 +1,38 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cpu implements processor feature detection for +// various CPU architectures. +package cpu + +// CacheLinePad is used to pad structs to avoid false sharing. +type CacheLinePad struct{ _ [cacheLineSize]byte } + +// X86 contains the supported CPU features of the +// current X86/AMD64 platform. If the current platform +// is not X86/AMD64 then all feature flags are false. +// +// X86 is padded to avoid false sharing. Further the HasAVX +// and HasAVX2 are only set if the OS supports XMM and YMM +// registers in addition to the CPUID feature bit being set. +var X86 struct { + _ CacheLinePad + HasAES bool // AES hardware implementation (AES NI) + HasADX bool // Multi-precision add-carry instruction extensions + HasAVX bool // Advanced vector extension + HasAVX2 bool // Advanced vector extension 2 + HasBMI1 bool // Bit manipulation instruction set 1 + HasBMI2 bool // Bit manipulation instruction set 2 + HasERMS bool // Enhanced REP for MOVSB and STOSB + HasFMA bool // Fused-multiply-add instructions + HasOSXSAVE bool // OS supports XSAVE/XRESTOR for saving/restoring XMM registers. + HasPCLMULQDQ bool // PCLMULQDQ instruction - most often used for AES-GCM + HasPOPCNT bool // Hamming weight instruction POPCNT. + HasSSE2 bool // Streaming SIMD extension 2 (always available on amd64) + HasSSE3 bool // Streaming SIMD extension 3 + HasSSSE3 bool // Supplemental streaming SIMD extension 3 + HasSSE41 bool // Streaming SIMD extension 4 and 4.1 + HasSSE42 bool // Streaming SIMD extension 4 and 4.2 + _ CacheLinePad +} diff --git a/vendor/golang.org/x/sys/cpu/cpu_arm.go b/vendor/golang.org/x/sys/cpu/cpu_arm.go new file mode 100644 index 00000000..d93036f7 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_arm.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const cacheLineSize = 32 diff --git a/vendor/golang.org/x/sys/cpu/cpu_arm64.go b/vendor/golang.org/x/sys/cpu/cpu_arm64.go new file mode 100644 index 00000000..1d2ab290 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_arm64.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const cacheLineSize = 64 diff --git a/vendor/golang.org/x/sys/cpu/cpu_gc_x86.go b/vendor/golang.org/x/sys/cpu/cpu_gc_x86.go new file mode 100644 index 00000000..f7cb4697 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_gc_x86.go @@ -0,0 +1,16 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 +// +build !gccgo + +package cpu + +// cpuid is implemented in cpu_x86.s for gc compiler +// and in cpu_gccgo.c for gccgo. +func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) + +// xgetbv with ecx = 0 is implemented in cpu_x86.s for gc compiler +// and in cpu_gccgo.c for gccgo. +func xgetbv() (eax, edx uint32) diff --git a/vendor/golang.org/x/sys/cpu/cpu_gccgo.c b/vendor/golang.org/x/sys/cpu/cpu_gccgo.c new file mode 100644 index 00000000..e363c7d1 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_gccgo.c @@ -0,0 +1,43 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 +// +build gccgo + +#include +#include + +// Need to wrap __get_cpuid_count because it's declared as static. +int +gccgoGetCpuidCount(uint32_t leaf, uint32_t subleaf, + uint32_t *eax, uint32_t *ebx, + uint32_t *ecx, uint32_t *edx) +{ + return __get_cpuid_count(leaf, subleaf, eax, ebx, ecx, edx); +} + +// xgetbv reads the contents of an XCR (Extended Control Register) +// specified in the ECX register into registers EDX:EAX. +// Currently, the only supported value for XCR is 0. +// +// TODO: Replace with a better alternative: +// +// #include +// +// #pragma GCC target("xsave") +// +// void gccgoXgetbv(uint32_t *eax, uint32_t *edx) { +// unsigned long long x = _xgetbv(0); +// *eax = x & 0xffffffff; +// *edx = (x >> 32) & 0xffffffff; +// } +// +// Note that _xgetbv is defined starting with GCC 8. +void +gccgoXgetbv(uint32_t *eax, uint32_t *edx) +{ + __asm(" xorl %%ecx, %%ecx\n" + " xgetbv" + : "=a"(*eax), "=d"(*edx)); +} diff --git a/vendor/golang.org/x/sys/cpu/cpu_gccgo.go b/vendor/golang.org/x/sys/cpu/cpu_gccgo.go new file mode 100644 index 00000000..ba49b91b --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_gccgo.go @@ -0,0 +1,26 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 +// +build gccgo + +package cpu + +//extern gccgoGetCpuidCount +func gccgoGetCpuidCount(eaxArg, ecxArg uint32, eax, ebx, ecx, edx *uint32) + +func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32) { + var a, b, c, d uint32 + gccgoGetCpuidCount(eaxArg, ecxArg, &a, &b, &c, &d) + return a, b, c, d +} + +//extern gccgoXgetbv +func gccgoXgetbv(eax, edx *uint32) + +func xgetbv() (eax, edx uint32) { + var a, d uint32 + gccgoXgetbv(&a, &d) + return a, d +} diff --git a/vendor/golang.org/x/sys/cpu/cpu_mips64x.go b/vendor/golang.org/x/sys/cpu/cpu_mips64x.go new file mode 100644 index 00000000..6165f121 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_mips64x.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build mips64 mips64le + +package cpu + +const cacheLineSize = 32 diff --git a/vendor/golang.org/x/sys/cpu/cpu_mipsx.go b/vendor/golang.org/x/sys/cpu/cpu_mipsx.go new file mode 100644 index 00000000..1269eee8 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_mipsx.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build mips mipsle + +package cpu + +const cacheLineSize = 32 diff --git a/vendor/golang.org/x/sys/cpu/cpu_ppc64x.go b/vendor/golang.org/x/sys/cpu/cpu_ppc64x.go new file mode 100644 index 00000000..d10759a5 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_ppc64x.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ppc64 ppc64le + +package cpu + +const cacheLineSize = 128 diff --git a/vendor/golang.org/x/sys/cpu/cpu_s390x.go b/vendor/golang.org/x/sys/cpu/cpu_s390x.go new file mode 100644 index 00000000..684c4f00 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_s390x.go @@ -0,0 +1,7 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cpu + +const cacheLineSize = 256 diff --git a/vendor/golang.org/x/sys/cpu/cpu_x86.go b/vendor/golang.org/x/sys/cpu/cpu_x86.go new file mode 100644 index 00000000..71e288b0 --- /dev/null +++ b/vendor/golang.org/x/sys/cpu/cpu_x86.go @@ -0,0 +1,55 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build 386 amd64 amd64p32 + +package cpu + +const cacheLineSize = 64 + +func init() { + maxID, _, _, _ := cpuid(0, 0) + + if maxID < 1 { + return + } + + _, _, ecx1, edx1 := cpuid(1, 0) + X86.HasSSE2 = isSet(26, edx1) + + X86.HasSSE3 = isSet(0, ecx1) + X86.HasPCLMULQDQ = isSet(1, ecx1) + X86.HasSSSE3 = isSet(9, ecx1) + X86.HasFMA = isSet(12, ecx1) + X86.HasSSE41 = isSet(19, ecx1) + X86.HasSSE42 = isSet(20, ecx1) + X86.HasPOPCNT = isSet(23, ecx1) + X86.HasAES = isSet(25, ecx1) + X86.HasOSXSAVE = isSet(27, ecx1) + + osSupportsAVX := false + // For XGETBV, OSXSAVE bit is required and sufficient. + if X86.HasOSXSAVE { + eax, _ := xgetbv() + // Check if XMM and YMM registers have OS support. + osSupportsAVX = isSet(1, eax) && isSet(2, eax) + } + + X86.HasAVX = isSet(28, ecx1) && osSupportsAVX + + if maxID < 7 { + return + } + + _, ebx7, _, _ := cpuid(7, 0) + X86.HasBMI1 = isSet(3, ebx7) + X86.HasAVX2 = isSet(5, ebx7) && osSupportsAVX + X86.HasBMI2 = isSet(8, ebx7) + X86.HasERMS = isSet(9, ebx7) + X86.HasADX = isSet(19, ebx7) +} + +func isSet(bitpos uint, value uint32) bool { + return value&(1<