From 5977de660b4719dd92b35d297744a1e4bd566dba Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Thu, 21 Jul 2022 18:50:10 +0200 Subject: [PATCH] Add suport for DoH over HTTP/3 --- dnscrypt-proxy/xtransport.go | 53 +- go.mod | 13 +- go.sum | 129 +- vendor/github.com/cheekybits/genny/.gitignore | 26 + .../github.com/cheekybits/genny/.travis.yml | 6 + vendor/github.com/cheekybits/genny/LICENSE | 22 + vendor/github.com/cheekybits/genny/README.md | 245 ++ vendor/github.com/cheekybits/genny/doc.go | 2 + .../cheekybits/genny/generic/doc.go | 2 + .../cheekybits/genny/generic/generic.go | 13 + vendor/github.com/cheekybits/genny/main.go | 154 ++ .../cheekybits/genny/out/lazy_file.go | 38 + .../cheekybits/genny/parse/builtins.go | 41 + .../github.com/cheekybits/genny/parse/doc.go | 14 + .../cheekybits/genny/parse/errors.go | 47 + .../cheekybits/genny/parse/parse.go | 298 +++ .../cheekybits/genny/parse/typesets.go | 89 + .../go-task/slim-sprig/.editorconfig | 14 + .../go-task/slim-sprig/.gitattributes | 1 + .../github.com/go-task/slim-sprig/.gitignore | 2 + .../go-task/slim-sprig/CHANGELOG.md | 364 +++ .../github.com/go-task/slim-sprig/LICENSE.txt | 19 + .../github.com/go-task/slim-sprig/README.md | 73 + .../go-task/slim-sprig/Taskfile.yml | 12 + .../github.com/go-task/slim-sprig/crypto.go | 24 + vendor/github.com/go-task/slim-sprig/date.go | 152 ++ .../github.com/go-task/slim-sprig/defaults.go | 163 ++ vendor/github.com/go-task/slim-sprig/dict.go | 118 + vendor/github.com/go-task/slim-sprig/doc.go | 19 + .../go-task/slim-sprig/functions.go | 317 +++ vendor/github.com/go-task/slim-sprig/list.go | 464 ++++ .../github.com/go-task/slim-sprig/network.go | 12 + .../github.com/go-task/slim-sprig/numeric.go | 228 ++ .../github.com/go-task/slim-sprig/reflect.go | 28 + vendor/github.com/go-task/slim-sprig/regex.go | 83 + .../github.com/go-task/slim-sprig/strings.go | 189 ++ vendor/github.com/go-task/slim-sprig/url.go | 66 + .../lucas-clemente/quic-go/.gitignore | 17 + .../lucas-clemente/quic-go/.golangci.yml | 44 + .../lucas-clemente/quic-go/Changelog.md | 109 + .../github.com/lucas-clemente/quic-go/LICENSE | 21 + .../lucas-clemente/quic-go/README.md | 61 + .../lucas-clemente/quic-go/buffer_pool.go | 80 + .../lucas-clemente/quic-go/client.go | 339 +++ .../lucas-clemente/quic-go/closed_conn.go | 112 + .../lucas-clemente/quic-go/codecov.yml | 21 + .../lucas-clemente/quic-go/config.go | 124 + .../quic-go/conn_id_generator.go | 140 ++ .../lucas-clemente/quic-go/conn_id_manager.go | 207 ++ .../lucas-clemente/quic-go/connection.go | 2006 +++++++++++++++++ .../lucas-clemente/quic-go/crypto_stream.go | 115 + .../quic-go/crypto_stream_manager.go | 61 + .../lucas-clemente/quic-go/datagram_queue.go | 87 + .../lucas-clemente/quic-go/errors.go | 58 + .../lucas-clemente/quic-go/frame_sorter.go | 224 ++ .../lucas-clemente/quic-go/framer.go | 171 ++ .../lucas-clemente/quic-go/http3/body.go | 130 ++ .../lucas-clemente/quic-go/http3/client.go | 423 ++++ .../quic-go/http3/error_codes.go | 73 + .../lucas-clemente/quic-go/http3/frames.go | 164 ++ .../quic-go/http3/gzip_reader.go | 39 + .../quic-go/http3/http_stream.go | 71 + .../lucas-clemente/quic-go/http3/request.go | 113 + .../quic-go/http3/request_writer.go | 283 +++ .../quic-go/http3/response_writer.go | 118 + .../lucas-clemente/quic-go/http3/roundtrip.go | 221 ++ .../lucas-clemente/quic-go/http3/server.go | 742 ++++++ .../lucas-clemente/quic-go/interface.go | 328 +++ .../internal/ackhandler/ack_eliciting.go | 20 + .../quic-go/internal/ackhandler/ackhandler.go | 21 + .../quic-go/internal/ackhandler/frame.go | 9 + .../quic-go/internal/ackhandler/gen.go | 3 + .../quic-go/internal/ackhandler/interfaces.go | 68 + .../quic-go/internal/ackhandler/mockgen.go | 3 + .../internal/ackhandler/packet_linkedlist.go | 217 ++ .../ackhandler/packet_number_generator.go | 76 + .../ackhandler/received_packet_handler.go | 136 ++ .../ackhandler/received_packet_history.go | 142 ++ .../ackhandler/received_packet_tracker.go | 196 ++ .../quic-go/internal/ackhandler/send_mode.go | 40 + .../ackhandler/sent_packet_handler.go | 838 +++++++ .../ackhandler/sent_packet_history.go | 108 + .../quic-go/internal/congestion/bandwidth.go | 25 + .../quic-go/internal/congestion/clock.go | 18 + .../quic-go/internal/congestion/cubic.go | 214 ++ .../internal/congestion/cubic_sender.go | 316 +++ .../internal/congestion/hybrid_slow_start.go | 113 + .../quic-go/internal/congestion/interface.go | 28 + .../quic-go/internal/congestion/pacer.go | 77 + .../flowcontrol/base_flow_controller.go | 125 + .../flowcontrol/connection_flow_controller.go | 112 + .../quic-go/internal/flowcontrol/interface.go | 42 + .../flowcontrol/stream_flow_controller.go | 149 ++ .../quic-go/internal/handshake/aead.go | 161 ++ .../internal/handshake/crypto_setup.go | 819 +++++++ .../internal/handshake/header_protector.go | 136 ++ .../quic-go/internal/handshake/hkdf.go | 29 + .../internal/handshake/initial_aead.go | 81 + .../quic-go/internal/handshake/interface.go | 102 + .../quic-go/internal/handshake/mockgen.go | 3 + .../quic-go/internal/handshake/retry.go | 62 + .../internal/handshake/session_ticket.go | 48 + .../handshake/tls_extension_handler.go | 68 + .../internal/handshake/token_generator.go | 134 ++ .../internal/handshake/token_protector.go | 89 + .../internal/handshake/updatable_aead.go | 323 +++ .../quic-go/internal/logutils/frame.go | 33 + .../internal/protocol/connection_id.go | 69 + .../internal/protocol/encryption_level.go | 30 + .../quic-go/internal/protocol/key_phase.go | 36 + .../internal/protocol/packet_number.go | 79 + .../quic-go/internal/protocol/params.go | 193 ++ .../quic-go/internal/protocol/perspective.go | 26 + .../quic-go/internal/protocol/protocol.go | 97 + .../quic-go/internal/protocol/stream.go | 76 + .../quic-go/internal/protocol/version.go | 114 + .../quic-go/internal/qerr/error_codes.go | 88 + .../quic-go/internal/qerr/errors.go | 124 + .../quic-go/internal/qtls/go116.go | 100 + .../quic-go/internal/qtls/go117.go | 100 + .../quic-go/internal/qtls/go118.go | 100 + .../quic-go/internal/qtls/go119.go | 100 + .../quic-go/internal/qtls/go120.go | 6 + .../quic-go/internal/qtls/go_oldversion.go | 7 + .../quic-go/internal/utils/atomic_bool.go | 22 + .../internal/utils/buffered_write_closer.go | 26 + .../internal/utils/byteinterval_linkedlist.go | 217 ++ .../quic-go/internal/utils/byteorder.go | 17 + .../internal/utils/byteorder_big_endian.go | 89 + .../quic-go/internal/utils/gen.go | 5 + .../quic-go/internal/utils/ip.go | 10 + .../quic-go/internal/utils/log.go | 131 ++ .../quic-go/internal/utils/minmax.go | 170 ++ .../internal/utils/new_connection_id.go | 12 + .../utils/newconnectionid_linkedlist.go | 217 ++ .../quic-go/internal/utils/packet_interval.go | 9 + .../utils/packetinterval_linkedlist.go | 217 ++ .../quic-go/internal/utils/rand.go | 29 + .../quic-go/internal/utils/rtt_stats.go | 127 ++ .../internal/utils/streamframe_interval.go | 9 + .../quic-go/internal/utils/timer.go | 53 + .../quic-go/internal/wire/ack_frame.go | 251 +++ .../quic-go/internal/wire/ack_range.go | 14 + .../internal/wire/connection_close_frame.go | 83 + .../quic-go/internal/wire/crypto_frame.go | 102 + .../internal/wire/data_blocked_frame.go | 38 + .../quic-go/internal/wire/datagram_frame.go | 85 + .../quic-go/internal/wire/extended_header.go | 249 ++ .../quic-go/internal/wire/frame_parser.go | 143 ++ .../internal/wire/handshake_done_frame.go | 28 + .../quic-go/internal/wire/header.go | 274 +++ .../quic-go/internal/wire/interface.go | 19 + .../quic-go/internal/wire/log.go | 72 + .../quic-go/internal/wire/max_data_frame.go | 40 + .../internal/wire/max_stream_data_frame.go | 46 + .../internal/wire/max_streams_frame.go | 55 + .../internal/wire/new_connection_id_frame.go | 80 + .../quic-go/internal/wire/new_token_frame.go | 48 + .../internal/wire/path_challenge_frame.go | 38 + .../internal/wire/path_response_frame.go | 38 + .../quic-go/internal/wire/ping_frame.go | 27 + .../quic-go/internal/wire/pool.go | 33 + .../internal/wire/reset_stream_frame.go | 58 + .../wire/retire_connection_id_frame.go | 36 + .../internal/wire/stop_sending_frame.go | 48 + .../wire/stream_data_blocked_frame.go | 46 + .../quic-go/internal/wire/stream_frame.go | 189 ++ .../internal/wire/streams_blocked_frame.go | 55 + .../internal/wire/transport_parameters.go | 476 ++++ .../internal/wire/version_negotiation.go | 54 + .../lucas-clemente/quic-go/logging/frame.go | 66 + .../quic-go/logging/interface.go | 134 ++ .../lucas-clemente/quic-go/logging/mockgen.go | 4 + .../quic-go/logging/multiplex.go | 219 ++ .../quic-go/logging/packet_header.go | 27 + .../lucas-clemente/quic-go/logging/types.go | 94 + .../lucas-clemente/quic-go/mockgen.go | 27 + .../lucas-clemente/quic-go/mockgen_private.sh | 49 + .../lucas-clemente/quic-go/mtu_discoverer.go | 74 + .../lucas-clemente/quic-go/multiplexer.go | 107 + .../quic-go/packet_handler_map.go | 489 ++++ .../lucas-clemente/quic-go/packet_packer.go | 894 ++++++++ .../lucas-clemente/quic-go/packet_unpacker.go | 196 ++ .../lucas-clemente/quic-go/quicvarint/io.go | 68 + .../quic-go/quicvarint/varint.go | 139 ++ .../lucas-clemente/quic-go/receive_stream.go | 331 +++ .../quic-go/retransmission_queue.go | 131 ++ .../lucas-clemente/quic-go/send_conn.go | 74 + .../lucas-clemente/quic-go/send_queue.go | 88 + .../lucas-clemente/quic-go/send_stream.go | 496 ++++ .../lucas-clemente/quic-go/server.go | 670 ++++++ .../lucas-clemente/quic-go/stream.go | 149 ++ .../lucas-clemente/quic-go/streams_map.go | 317 +++ .../quic-go/streams_map_generic_helper.go | 18 + .../quic-go/streams_map_incoming_bidi.go | 192 ++ .../quic-go/streams_map_incoming_generic.go | 190 ++ .../quic-go/streams_map_incoming_uni.go | 192 ++ .../quic-go/streams_map_outgoing_bidi.go | 226 ++ .../quic-go/streams_map_outgoing_generic.go | 224 ++ .../quic-go/streams_map_outgoing_uni.go | 226 ++ .../lucas-clemente/quic-go/sys_conn.go | 80 + .../lucas-clemente/quic-go/sys_conn_df.go | 16 + .../quic-go/sys_conn_df_linux.go | 40 + .../quic-go/sys_conn_df_windows.go | 46 + .../quic-go/sys_conn_helper_darwin.go | 22 + .../quic-go/sys_conn_helper_freebsd.go | 22 + .../quic-go/sys_conn_helper_linux.go | 20 + .../lucas-clemente/quic-go/sys_conn_no_oob.go | 16 + .../lucas-clemente/quic-go/sys_conn_oob.go | 257 +++ .../quic-go/sys_conn_windows.go | 40 + .../lucas-clemente/quic-go/token_store.go | 117 + .../lucas-clemente/quic-go/tools.go | 9 + .../quic-go/window_update_queue.go | 71 + .../marten-seemann/qpack/.codecov.yml | 7 + .../marten-seemann/qpack/.gitignore | 6 + .../marten-seemann/qpack/.gitmodules | 3 + .../marten-seemann/qpack/.golangci.yml | 27 + .../marten-seemann/qpack/LICENSE.md | 7 + .../github.com/marten-seemann/qpack/README.md | 21 + .../marten-seemann/qpack/decoder.go | 271 +++ .../marten-seemann/qpack/encoder.go | 95 + .../marten-seemann/qpack/header_field.go | 16 + .../marten-seemann/qpack/static_table.go | 254 +++ .../github.com/marten-seemann/qpack/varint.go | 66 + .../marten-seemann/qtls-go1-16/LICENSE | 27 + .../marten-seemann/qtls-go1-16/README.md | 6 + .../marten-seemann/qtls-go1-16/alert.go | 102 + .../marten-seemann/qtls-go1-16/auth.go | 289 +++ .../qtls-go1-16/cipher_suites.go | 532 +++++ .../marten-seemann/qtls-go1-16/common.go | 1576 +++++++++++++ .../marten-seemann/qtls-go1-16/common_js.go | 12 + .../marten-seemann/qtls-go1-16/common_nojs.go | 20 + .../marten-seemann/qtls-go1-16/conn.go | 1536 +++++++++++++ .../qtls-go1-16/handshake_client.go | 1105 +++++++++ .../qtls-go1-16/handshake_client_tls13.go | 740 ++++++ .../qtls-go1-16/handshake_messages.go | 1832 +++++++++++++++ .../qtls-go1-16/handshake_server.go | 878 ++++++++ .../qtls-go1-16/handshake_server_tls13.go | 912 ++++++++ .../qtls-go1-16/key_agreement.go | 338 +++ .../qtls-go1-16/key_schedule.go | 199 ++ .../marten-seemann/qtls-go1-16/prf.go | 283 +++ .../marten-seemann/qtls-go1-16/ticket.go | 259 +++ .../marten-seemann/qtls-go1-16/tls.go | 393 ++++ .../marten-seemann/qtls-go1-16/unsafe.go | 96 + .../marten-seemann/qtls-go1-17/LICENSE | 27 + .../marten-seemann/qtls-go1-17/README.md | 6 + .../marten-seemann/qtls-go1-17/alert.go | 102 + .../marten-seemann/qtls-go1-17/auth.go | 289 +++ .../qtls-go1-17/cipher_suites.go | 691 ++++++ .../marten-seemann/qtls-go1-17/common.go | 1485 ++++++++++++ .../marten-seemann/qtls-go1-17/conn.go | 1601 +++++++++++++ .../marten-seemann/qtls-go1-17/cpu.go | 22 + .../marten-seemann/qtls-go1-17/cpu_other.go | 12 + .../qtls-go1-17/handshake_client.go | 1111 +++++++++ .../qtls-go1-17/handshake_client_tls13.go | 732 ++++++ .../qtls-go1-17/handshake_messages.go | 1832 +++++++++++++++ .../qtls-go1-17/handshake_server.go | 905 ++++++++ .../qtls-go1-17/handshake_server_tls13.go | 896 ++++++++ .../qtls-go1-17/key_agreement.go | 357 +++ .../qtls-go1-17/key_schedule.go | 199 ++ .../marten-seemann/qtls-go1-17/prf.go | 283 +++ .../marten-seemann/qtls-go1-17/ticket.go | 274 +++ .../marten-seemann/qtls-go1-17/tls.go | 362 +++ .../marten-seemann/qtls-go1-17/unsafe.go | 96 + .../marten-seemann/qtls-go1-18/LICENSE | 27 + .../marten-seemann/qtls-go1-18/README.md | 6 + .../marten-seemann/qtls-go1-18/alert.go | 102 + .../marten-seemann/qtls-go1-18/auth.go | 289 +++ .../qtls-go1-18/cipher_suites.go | 691 ++++++ .../marten-seemann/qtls-go1-18/common.go | 1507 +++++++++++++ .../marten-seemann/qtls-go1-18/conn.go | 1608 +++++++++++++ .../marten-seemann/qtls-go1-18/cpu.go | 22 + .../marten-seemann/qtls-go1-18/cpu_other.go | 12 + .../qtls-go1-18/handshake_client.go | 1111 +++++++++ .../qtls-go1-18/handshake_client_tls13.go | 732 ++++++ .../qtls-go1-18/handshake_messages.go | 1831 +++++++++++++++ .../qtls-go1-18/handshake_server.go | 905 ++++++++ .../qtls-go1-18/handshake_server_tls13.go | 896 ++++++++ .../qtls-go1-18/key_agreement.go | 357 +++ .../qtls-go1-18/key_schedule.go | 199 ++ .../marten-seemann/qtls-go1-18/prf.go | 283 +++ .../marten-seemann/qtls-go1-18/ticket.go | 274 +++ .../marten-seemann/qtls-go1-18/tls.go | 362 +++ .../marten-seemann/qtls-go1-18/unsafe.go | 96 + .../marten-seemann/qtls-go1-19/LICENSE | 27 + .../marten-seemann/qtls-go1-19/README.md | 6 + .../marten-seemann/qtls-go1-19/alert.go | 102 + .../marten-seemann/qtls-go1-19/auth.go | 293 +++ .../qtls-go1-19/cipher_suites.go | 693 ++++++ .../marten-seemann/qtls-go1-19/common.go | 1512 +++++++++++++ .../marten-seemann/qtls-go1-19/conn.go | 1610 +++++++++++++ .../marten-seemann/qtls-go1-19/cpu.go | 22 + .../marten-seemann/qtls-go1-19/cpu_other.go | 12 + .../qtls-go1-19/handshake_client.go | 1117 +++++++++ .../qtls-go1-19/handshake_client_tls13.go | 736 ++++++ .../qtls-go1-19/handshake_messages.go | 1843 +++++++++++++++ .../qtls-go1-19/handshake_server.go | 905 ++++++++ .../qtls-go1-19/handshake_server_tls13.go | 900 ++++++++ .../qtls-go1-19/key_agreement.go | 357 +++ .../qtls-go1-19/key_schedule.go | 199 ++ .../marten-seemann/qtls-go1-19/notboring.go | 18 + .../marten-seemann/qtls-go1-19/prf.go | 283 +++ .../marten-seemann/qtls-go1-19/ticket.go | 274 +++ .../marten-seemann/qtls-go1-19/tls.go | 362 +++ .../marten-seemann/qtls-go1-19/unsafe.go | 96 + vendor/github.com/nxadm/tail/.gitignore | 3 + vendor/github.com/nxadm/tail/CHANGES.md | 56 + vendor/github.com/nxadm/tail/Dockerfile | 19 + vendor/github.com/nxadm/tail/LICENSE | 21 + vendor/github.com/nxadm/tail/README.md | 44 + .../github.com/nxadm/tail/ratelimiter/Licence | 7 + .../nxadm/tail/ratelimiter/leakybucket.go | 97 + .../nxadm/tail/ratelimiter/memory.go | 60 + .../nxadm/tail/ratelimiter/storage.go | 6 + vendor/github.com/nxadm/tail/tail.go | 455 ++++ vendor/github.com/nxadm/tail/tail_posix.go | 17 + vendor/github.com/nxadm/tail/tail_windows.go | 19 + vendor/github.com/nxadm/tail/util/util.go | 49 + .../nxadm/tail/watch/filechanges.go | 37 + vendor/github.com/nxadm/tail/watch/inotify.go | 136 ++ .../nxadm/tail/watch/inotify_tracker.go | 249 ++ vendor/github.com/nxadm/tail/watch/polling.go | 119 + vendor/github.com/nxadm/tail/watch/watch.go | 21 + .../github.com/nxadm/tail/winfile/winfile.go | 93 + vendor/github.com/onsi/ginkgo/LICENSE | 20 + .../github.com/onsi/ginkgo/config/config.go | 232 ++ .../onsi/ginkgo/formatter/formatter.go | 190 ++ .../onsi/ginkgo/ginkgo/bootstrap_command.go | 200 ++ .../onsi/ginkgo/ginkgo/build_command.go | 66 + .../ginkgo/ginkgo/convert/ginkgo_ast_nodes.go | 123 + .../onsi/ginkgo/ginkgo/convert/import.go | 90 + .../ginkgo/ginkgo/convert/package_rewriter.go | 128 ++ .../onsi/ginkgo/ginkgo/convert/test_finder.go | 56 + .../ginkgo/convert/testfile_rewriter.go | 162 ++ .../ginkgo/convert/testing_t_rewriter.go | 130 ++ .../onsi/ginkgo/ginkgo/convert_command.go | 51 + .../onsi/ginkgo/ginkgo/generate_command.go | 273 +++ .../onsi/ginkgo/ginkgo/help_command.go | 31 + .../interrupthandler/interrupt_handler.go | 52 + .../sigquit_swallower_unix.go | 14 + .../sigquit_swallower_windows.go | 7 + vendor/github.com/onsi/ginkgo/ginkgo/main.go | 308 +++ .../onsi/ginkgo/ginkgo/nodot/nodot.go | 196 ++ .../onsi/ginkgo/ginkgo/nodot_command.go | 77 + .../onsi/ginkgo/ginkgo/notifications.go | 141 ++ .../onsi/ginkgo/ginkgo/outline/ginkgo.go | 243 ++ .../onsi/ginkgo/ginkgo/outline/import.go | 65 + .../onsi/ginkgo/ginkgo/outline/outline.go | 107 + .../onsi/ginkgo/ginkgo/outline_command.go | 95 + .../onsi/ginkgo/ginkgo/run_command.go | 315 +++ .../run_watch_and_build_command_flags.go | 169 ++ .../onsi/ginkgo/ginkgo/suite_runner.go | 173 ++ .../ginkgo/ginkgo/testrunner/build_args.go | 7 + .../ginkgo/testrunner/build_args_old.go | 7 + .../ginkgo/ginkgo/testrunner/log_writer.go | 52 + .../ginkgo/ginkgo/testrunner/run_result.go | 27 + .../ginkgo/ginkgo/testrunner/test_runner.go | 554 +++++ .../ginkgo/ginkgo/testsuite/test_suite.go | 115 + .../ginkgo/testsuite/vendor_check_go15.go | 16 + .../ginkgo/testsuite/vendor_check_go16.go | 15 + .../onsi/ginkgo/ginkgo/unfocus_command.go | 180 ++ .../onsi/ginkgo/ginkgo/version_command.go | 24 + .../onsi/ginkgo/ginkgo/watch/delta.go | 22 + .../onsi/ginkgo/ginkgo/watch/delta_tracker.go | 75 + .../onsi/ginkgo/ginkgo/watch/dependencies.go | 92 + .../onsi/ginkgo/ginkgo/watch/package_hash.go | 104 + .../ginkgo/ginkgo/watch/package_hashes.go | 85 + .../onsi/ginkgo/ginkgo/watch/suite.go | 87 + .../onsi/ginkgo/ginkgo/watch_command.go | 175 ++ .../internal/codelocation/code_location.go | 48 + .../internal/containernode/container_node.go | 151 ++ .../onsi/ginkgo/internal/failer/failer.go | 92 + .../ginkgo/internal/leafnodes/benchmarker.go | 103 + .../ginkgo/internal/leafnodes/interfaces.go | 19 + .../onsi/ginkgo/internal/leafnodes/it_node.go | 47 + .../ginkgo/internal/leafnodes/measure_node.go | 62 + .../onsi/ginkgo/internal/leafnodes/runner.go | 117 + .../ginkgo/internal/leafnodes/setup_nodes.go | 48 + .../ginkgo/internal/leafnodes/suite_nodes.go | 55 + .../synchronized_after_suite_node.go | 90 + .../synchronized_before_suite_node.go | 181 ++ .../onsi/ginkgo/internal/remote/aggregator.go | 249 ++ .../internal/remote/forwarding_reporter.go | 147 ++ .../internal/remote/output_interceptor.go | 13 + .../remote/output_interceptor_unix.go | 82 + .../internal/remote/output_interceptor_win.go | 36 + .../onsi/ginkgo/internal/remote/server.go | 224 ++ .../onsi/ginkgo/internal/spec/spec.go | 247 ++ .../onsi/ginkgo/internal/spec/specs.go | 144 ++ .../internal/spec_iterator/index_computer.go | 55 + .../spec_iterator/parallel_spec_iterator.go | 59 + .../spec_iterator/serial_spec_iterator.go | 45 + .../sharded_parallel_spec_iterator.go | 47 + .../internal/spec_iterator/spec_iterator.go | 20 + .../ginkgo/internal/writer/fake_writer.go | 36 + .../onsi/ginkgo/internal/writer/writer.go | 89 + .../onsi/ginkgo/reporters/default_reporter.go | 87 + .../onsi/ginkgo/reporters/fake_reporter.go | 59 + .../onsi/ginkgo/reporters/junit_reporter.go | 178 ++ .../onsi/ginkgo/reporters/reporter.go | 15 + .../reporters/stenographer/console_logging.go | 64 + .../stenographer/fake_stenographer.go | 142 ++ .../reporters/stenographer/stenographer.go | 572 +++++ .../stenographer/support/go-colorable/LICENSE | 21 + .../support/go-colorable/README.md | 43 + .../support/go-colorable/colorable_others.go | 24 + .../support/go-colorable/colorable_windows.go | 783 +++++++ .../support/go-colorable/noncolorable.go | 57 + .../stenographer/support/go-isatty/LICENSE | 9 + .../stenographer/support/go-isatty/README.md | 37 + .../stenographer/support/go-isatty/doc.go | 2 + .../support/go-isatty/isatty_appengine.go | 9 + .../support/go-isatty/isatty_bsd.go | 18 + .../support/go-isatty/isatty_linux.go | 18 + .../support/go-isatty/isatty_solaris.go | 16 + .../support/go-isatty/isatty_windows.go | 19 + .../ginkgo/reporters/teamcity_reporter.go | 106 + .../onsi/ginkgo/types/code_location.go | 15 + .../onsi/ginkgo/types/deprecation_support.go | 150 ++ .../onsi/ginkgo/types/synchronization.go | 30 + vendor/github.com/onsi/ginkgo/types/types.go | 174 ++ vendor/golang.org/x/crypto/cryptobyte/asn1.go | 809 +++++++ .../x/crypto/cryptobyte/asn1/asn1.go | 46 + .../golang.org/x/crypto/cryptobyte/builder.go | 337 +++ .../golang.org/x/crypto/cryptobyte/string.go | 161 ++ vendor/gopkg.in/tomb.v1/LICENSE | 29 + vendor/gopkg.in/tomb.v1/README.md | 4 + vendor/gopkg.in/tomb.v1/tomb.go | 176 ++ vendor/modules.txt | 79 +- 429 files changed, 87237 insertions(+), 7 deletions(-) create mode 100644 vendor/github.com/cheekybits/genny/.gitignore create mode 100644 vendor/github.com/cheekybits/genny/.travis.yml create mode 100644 vendor/github.com/cheekybits/genny/LICENSE create mode 100644 vendor/github.com/cheekybits/genny/README.md create mode 100644 vendor/github.com/cheekybits/genny/doc.go create mode 100644 vendor/github.com/cheekybits/genny/generic/doc.go create mode 100644 vendor/github.com/cheekybits/genny/generic/generic.go create mode 100644 vendor/github.com/cheekybits/genny/main.go create mode 100644 vendor/github.com/cheekybits/genny/out/lazy_file.go create mode 100644 vendor/github.com/cheekybits/genny/parse/builtins.go create mode 100644 vendor/github.com/cheekybits/genny/parse/doc.go create mode 100644 vendor/github.com/cheekybits/genny/parse/errors.go create mode 100644 vendor/github.com/cheekybits/genny/parse/parse.go create mode 100644 vendor/github.com/cheekybits/genny/parse/typesets.go create mode 100644 vendor/github.com/go-task/slim-sprig/.editorconfig create mode 100644 vendor/github.com/go-task/slim-sprig/.gitattributes create mode 100644 vendor/github.com/go-task/slim-sprig/.gitignore create mode 100644 vendor/github.com/go-task/slim-sprig/CHANGELOG.md create mode 100644 vendor/github.com/go-task/slim-sprig/LICENSE.txt create mode 100644 vendor/github.com/go-task/slim-sprig/README.md create mode 100644 vendor/github.com/go-task/slim-sprig/Taskfile.yml create mode 100644 vendor/github.com/go-task/slim-sprig/crypto.go create mode 100644 vendor/github.com/go-task/slim-sprig/date.go create mode 100644 vendor/github.com/go-task/slim-sprig/defaults.go create mode 100644 vendor/github.com/go-task/slim-sprig/dict.go create mode 100644 vendor/github.com/go-task/slim-sprig/doc.go create mode 100644 vendor/github.com/go-task/slim-sprig/functions.go create mode 100644 vendor/github.com/go-task/slim-sprig/list.go create mode 100644 vendor/github.com/go-task/slim-sprig/network.go create mode 100644 vendor/github.com/go-task/slim-sprig/numeric.go create mode 100644 vendor/github.com/go-task/slim-sprig/reflect.go create mode 100644 vendor/github.com/go-task/slim-sprig/regex.go create mode 100644 vendor/github.com/go-task/slim-sprig/strings.go create mode 100644 vendor/github.com/go-task/slim-sprig/url.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/.gitignore create mode 100644 vendor/github.com/lucas-clemente/quic-go/.golangci.yml create mode 100644 vendor/github.com/lucas-clemente/quic-go/Changelog.md create mode 100644 vendor/github.com/lucas-clemente/quic-go/LICENSE create mode 100644 vendor/github.com/lucas-clemente/quic-go/README.md create mode 100644 vendor/github.com/lucas-clemente/quic-go/buffer_pool.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/client.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/closed_conn.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/codecov.yml create mode 100644 vendor/github.com/lucas-clemente/quic-go/config.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/conn_id_generator.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/conn_id_manager.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/connection.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto_stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/crypto_stream_manager.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/datagram_queue.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/errors.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/frame_sorter.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/framer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/body.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/client.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/error_codes.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/frames.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/gzip_reader.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/http_stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/request.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/request_writer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/response_writer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/roundtrip.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/http3/server.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ack_eliciting.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_number_generator.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/congestion/pacer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/header_protector.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/hkdf.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/initial_aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/session_ticket.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_generator.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_protector.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/updatable_aead.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/logutils/frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/key_phase.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/params.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qerr/error_codes.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qerr/errors.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go117.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go119.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/buffered_write_closer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/ip.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/rand.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/rtt_stats.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/datagram_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/max_streams_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/new_connection_id_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/new_token_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/pool.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/reset_stream_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/retire_connection_id_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_data_blocked_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/streams_blocked_frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/transport_parameters.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/logging/frame.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/logging/interface.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/logging/multiplex.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/logging/packet_header.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/logging/types.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/mockgen.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh create mode 100644 vendor/github.com/lucas-clemente/quic-go/mtu_discoverer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/multiplexer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/packet_packer.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/quicvarint/io.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/quicvarint/varint.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/receive_stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/retransmission_queue.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/send_conn.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/send_queue.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/send_stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/server.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/stream.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_df.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_df_linux.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_df_windows.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_darwin.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_freebsd.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_linux.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_no_oob.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_oob.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/sys_conn_windows.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/token_store.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/tools.go create mode 100644 vendor/github.com/lucas-clemente/quic-go/window_update_queue.go create mode 100644 vendor/github.com/marten-seemann/qpack/.codecov.yml create mode 100644 vendor/github.com/marten-seemann/qpack/.gitignore create mode 100644 vendor/github.com/marten-seemann/qpack/.gitmodules create mode 100644 vendor/github.com/marten-seemann/qpack/.golangci.yml create mode 100644 vendor/github.com/marten-seemann/qpack/LICENSE.md create mode 100644 vendor/github.com/marten-seemann/qpack/README.md create mode 100644 vendor/github.com/marten-seemann/qpack/decoder.go create mode 100644 vendor/github.com/marten-seemann/qpack/encoder.go create mode 100644 vendor/github.com/marten-seemann/qpack/header_field.go create mode 100644 vendor/github.com/marten-seemann/qpack/static_table.go create mode 100644 vendor/github.com/marten-seemann/qpack/varint.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/LICENSE create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/README.md create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/alert.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/auth.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/common.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/common_js.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/conn.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/prf.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/ticket.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/tls.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/unsafe.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/LICENSE create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/README.md create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/alert.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/auth.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/common.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/conn.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/cpu.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/cpu_other.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/key_schedule.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/prf.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/ticket.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/tls.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/unsafe.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/LICENSE create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/README.md create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/alert.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/auth.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/cipher_suites.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/common.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/conn.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/cpu.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/cpu_other.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/handshake_client.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/handshake_client_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/handshake_messages.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/handshake_server.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/handshake_server_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/key_agreement.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/key_schedule.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/prf.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/ticket.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/tls.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/LICENSE create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/README.md create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/alert.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/auth.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/cipher_suites.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/common.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/conn.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/cpu.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/handshake_client.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/handshake_client_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/handshake_messages.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/handshake_server.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/handshake_server_tls13.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/key_agreement.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/notboring.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/prf.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/ticket.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/tls.go create mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go create mode 100644 vendor/github.com/nxadm/tail/.gitignore create mode 100644 vendor/github.com/nxadm/tail/CHANGES.md create mode 100644 vendor/github.com/nxadm/tail/Dockerfile create mode 100644 vendor/github.com/nxadm/tail/LICENSE create mode 100644 vendor/github.com/nxadm/tail/README.md create mode 100644 vendor/github.com/nxadm/tail/ratelimiter/Licence create mode 100644 vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go create mode 100644 vendor/github.com/nxadm/tail/ratelimiter/memory.go create mode 100644 vendor/github.com/nxadm/tail/ratelimiter/storage.go create mode 100644 vendor/github.com/nxadm/tail/tail.go create mode 100644 vendor/github.com/nxadm/tail/tail_posix.go create mode 100644 vendor/github.com/nxadm/tail/tail_windows.go create mode 100644 vendor/github.com/nxadm/tail/util/util.go create mode 100644 vendor/github.com/nxadm/tail/watch/filechanges.go create mode 100644 vendor/github.com/nxadm/tail/watch/inotify.go create mode 100644 vendor/github.com/nxadm/tail/watch/inotify_tracker.go create mode 100644 vendor/github.com/nxadm/tail/watch/polling.go create mode 100644 vendor/github.com/nxadm/tail/watch/watch.go create mode 100644 vendor/github.com/nxadm/tail/winfile/winfile.go create mode 100644 vendor/github.com/onsi/ginkgo/LICENSE create mode 100644 vendor/github.com/onsi/ginkgo/config/config.go create mode 100644 vendor/github.com/onsi/ginkgo/formatter/formatter.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/build_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/generate_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/help_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_unix.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_windows.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/main.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/notifications.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/outline/ginkgo.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/outline/import.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/outline/outline.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/run_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/unfocus_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/version_command.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch/delta.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch/delta_tracker.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch/dependencies.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hash.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hashes.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch/suite.go create mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/failer/failer.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/server.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/spec/spec.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/spec/specs.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go create mode 100644 vendor/github.com/onsi/ginkgo/internal/writer/writer.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/default_reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/LICENSE create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_windows.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go create mode 100644 vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/types/code_location.go create mode 100644 vendor/github.com/onsi/ginkgo/types/deprecation_support.go create mode 100644 vendor/github.com/onsi/ginkgo/types/synchronization.go create mode 100644 vendor/github.com/onsi/ginkgo/types/types.go create mode 100644 vendor/golang.org/x/crypto/cryptobyte/asn1.go create mode 100644 vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go create mode 100644 vendor/golang.org/x/crypto/cryptobyte/builder.go create mode 100644 vendor/golang.org/x/crypto/cryptobyte/string.go create mode 100644 vendor/gopkg.in/tomb.v1/LICENSE create mode 100644 vendor/gopkg.in/tomb.v1/README.md create mode 100644 vendor/gopkg.in/tomb.v1/tomb.go diff --git a/dnscrypt-proxy/xtransport.go b/dnscrypt-proxy/xtransport.go index d86a2c85..32d802df 100644 --- a/dnscrypt-proxy/xtransport.go +++ b/dnscrypt-proxy/xtransport.go @@ -22,6 +22,7 @@ import ( "github.com/jedisct1/dlog" stamps "github.com/jedisct1/go-dnsstamps" + "github.com/lucas-clemente/quic-go/http3" "github.com/miekg/dns" "golang.org/x/net/http2" netproxy "golang.org/x/net/proxy" @@ -46,11 +47,18 @@ type CachedIPs struct { cache map[string]*CachedIPItem } +type AltSupport struct { + sync.RWMutex + cache map[string]uint16 +} + type XTransport struct { transport *http.Transport + h3Transport *http3.RoundTripper keepAlive time.Duration timeout time.Duration cachedIPs CachedIPs + altSupport AltSupport bootstrapResolvers []string mainProto string ignoreSystemDNS bool @@ -69,6 +77,7 @@ func NewXTransport() *XTransport { } xTransport := XTransport{ cachedIPs: CachedIPs{cache: make(map[string]*CachedIPItem)}, + altSupport: AltSupport{cache: make(map[string]uint16)}, keepAlive: DefaultKeepAlive, timeout: DefaultTimeout, bootstrapResolvers: []string{DefaultBootstrapResolver}, @@ -212,6 +221,8 @@ func (xTransport *XTransport) rebuildTransport() { http2Transport.AllowHTTP = false } xTransport.transport = transport + h3Transport := &http3.RoundTripper{DisableCompression: true, TLSClientConfig: &tlsClientConfig} + xTransport.h3Transport = h3Transport } func (xTransport *XTransport) resolveUsingSystem(host string) (ip net.IP, ttl time.Duration, err error) { @@ -379,7 +390,20 @@ func (xTransport *XTransport) Fetch( if timeout <= 0 { timeout = xTransport.timeout } - client := http.Client{Transport: xTransport.transport, Timeout: timeout} + client := http.Client{ + Transport: xTransport.transport, + Timeout: timeout, + } + host, port := ExtractHostAndPort(url.Host, 443) + xTransport.altSupport.RLock() + altPort, hasAltSupport := xTransport.altSupport.cache[url.Host] + xTransport.altSupport.RUnlock() + if hasAltSupport { + if int(altPort) == port { + client.Transport = xTransport.h3Transport + dlog.Debugf("Using HTTP/3 transport for [%s]", url.Host) + } + } header := map[string][]string{"User-Agent": {"dnscrypt-proxy"}} if len(accept) > 0 { header["Accept"] = []string{accept} @@ -396,7 +420,6 @@ func (xTransport *XTransport) Fetch( url2.RawQuery = qs.Encode() url = &url2 } - host, _ := ExtractHostAndPort(url.Host, 0) if xTransport.proxyDialer == nil && strings.HasSuffix(host, ".onion") { return nil, 0, nil, 0, errors.New("Onion service is not reachable without Tor") } @@ -444,6 +467,32 @@ func (xTransport *XTransport) Fetch( } return nil, statusCode, nil, rtt, err } + if !hasAltSupport { + if alt, found := resp.Header["Alt-Svc"]; found { + dlog.Debugf("Alt-Svc [%s]: [%s]", url.Host, alt) + altPort := uint16(port) + for i, xalt := range alt { + for j, v := range strings.Split(xalt, ";") { + if i > 8 || j > 16 { + break + } + v = strings.TrimSpace(v) + if strings.HasPrefix(v, "h3=\":") { + v = strings.TrimPrefix(v, "h3=\":") + v = strings.TrimSuffix(v, "\"") + if xAltPort, err := strconv.ParseUint(v, 10, 16); err == nil && xAltPort <= 65536 { + altPort = uint16(xAltPort) + dlog.Debugf("Using HTTP/3 for [%s]", url.Host) + break + } + } + } + } + xTransport.altSupport.Lock() + xTransport.altSupport.cache[url.Host] = altPort + xTransport.altSupport.Unlock() + } + } tls := resp.TLS bin, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength)) if err != nil { diff --git a/go.mod b/go.mod index 70954f0e..71aafbce 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/jedisct1/xsecretbox v0.0.0-20210927135450-ebe41aef7bef github.com/k-sone/critbitgo v1.4.0 github.com/kardianos/service v1.2.1 + github.com/lucas-clemente/quic-go v0.28.0 github.com/miekg/dns v1.1.50 github.com/powerman/check v1.6.0 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d @@ -40,6 +41,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.1 // indirect github.com/charithe/durationcheck v0.0.8 // indirect github.com/chavacava/garif v0.0.0-20210405164556-e8a0a408d6af // indirect + github.com/cheekybits/genny v1.0.0 // indirect github.com/daixiang0/gci v0.2.8 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/denis-tingajkin/go-header v0.4.2 // indirect @@ -50,6 +52,7 @@ require ( github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/fzipp/gocyclo v0.3.1 // indirect github.com/go-critic/go-critic v0.5.6 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/go-toolsmith/astcast v1.0.0 // indirect github.com/go-toolsmith/astcopy v1.0.0 // indirect github.com/go-toolsmith/astequal v1.0.0 // indirect @@ -60,7 +63,7 @@ require ( github.com/go-xmlfmt/xmlfmt v0.0.0-20191208150333-d5b6f63a941b // indirect github.com/gobwas/glob v0.2.3 // indirect github.com/gofrs/flock v0.8.0 // indirect - github.com/golang/protobuf v1.5.0 // indirect + github.com/golang/protobuf v1.5.2 // indirect github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2 // indirect github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect github.com/golangci/go-misc v0.0.0-20180628070357-927a3d87b613 // indirect @@ -95,6 +98,11 @@ require ( github.com/ldez/tagliatelle v0.2.0 // indirect github.com/magiconair/properties v1.8.1 // indirect github.com/maratori/testpackage v1.0.1 // indirect + github.com/marten-seemann/qpack v0.2.1 // indirect + github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect + github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 // indirect github.com/matoous/godox v0.0.0-20210227103229-6504466cf951 // indirect github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-isatty v0.0.12 // indirect @@ -111,7 +119,9 @@ require ( github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 // indirect github.com/nishanths/exhaustive v0.1.0 // indirect github.com/nishanths/predeclared v0.2.1 // indirect + github.com/nxadm/tail v1.4.8 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect + github.com/onsi/ginkgo v1.16.4 // indirect github.com/pelletier/go-toml v1.2.0 // indirect github.com/phayes/checkstyle v0.0.0-20170904204023-bfd46e6a821d // indirect github.com/pkg/errors v0.9.1 // indirect @@ -160,6 +170,7 @@ require ( google.golang.org/grpc v1.38.0 // indirect google.golang.org/protobuf v1.27.0 // indirect gopkg.in/ini.v1 v1.51.0 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect honnef.co/go/tools v0.2.0 // indirect diff --git a/go.sum b/go.sum index b42bd2f3..472b76cf 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,9 @@ 4d63.com/gochecknoglobals v0.0.0-20201008074935-acfc0b28355a/go.mod h1:wfdC5ZjKSPr7CybKEcgJhUOgeAQW1+7WcyK8OvUilfo= bitbucket.org/creachadair/shell v0.0.6/go.mod h1:8Qqi/cYk7vPnsOePHroKXDJYmb5x7ENhtiFtfZq8K+M= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= @@ -35,7 +37,12 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= contrib.go.opencensus.io/exporter/stackdriver v0.13.4/go.mod h1:aXENhDJ1Y4lIg4EUaVTwzvYETVNZk10Pu26tevFKLUc= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.1.0 h1:ksErzDEI1khOiGPgpwuI7x2ebx/uXQNw7xJpn9Eq1+I= github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= @@ -61,6 +68,7 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alexkohler/prealloc v1.0.0 h1:Hbq0/3fJPQhNkN0dR95AVrr6R7tou91y0uHG5pOcUuw= github.com/alexkohler/prealloc v1.0.0/go.mod h1:VetnK3dIgFBBKmg0YnD9F9x6Icjd+9cvfHR56wJVlKE= github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/antihax/optional v0.0.0-20180407024304-ca021399b1a6/go.mod h1:V8iCPQYkqmusNa815XgQio277wI47sdRh1dUOLdyC6Q= github.com/aokoli/goutils v1.0.1/go.mod h1:SijmP0QR8LtwsmDs8Yii5Z/S4trXFGFC2oO5g9DP+DQ= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= @@ -84,6 +92,8 @@ github.com/bkielbasa/cyclop v1.2.0 h1:7Jmnh0yL2DjKfw28p86YTd/B4lRGcNuu12sKE35sM7 github.com/bkielbasa/cyclop v1.2.0/go.mod h1:qOI0yy6A7dYC4Zgsa72Ppm9kONl0RoIlPbzot9mhmeI= github.com/bombsimon/wsl/v3 v3.3.0 h1:Mka/+kRLoQJq7g2rggtgQsjuI/K5Efd87WX96EWFxjM= github.com/bombsimon/wsl/v3 v3.3.0/go.mod h1:st10JtZYLE4D5sC7b8xV4zTKZwAQjCH/Hy2Pm1FNZIc= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= @@ -92,6 +102,8 @@ github.com/charithe/durationcheck v0.0.8 h1:cnZrThioNW9gSV5JsRIXmkyHUbcDH7Y9hkzF github.com/charithe/durationcheck v0.0.8/go.mod h1:SSbRIBVfMjCi/kEB6K65XEA83D6prSM8ap1UCpNKtgg= github.com/chavacava/garif v0.0.0-20210405164556-e8a0a408d6af h1:spmv8nSH9h5oCQf40jt/ufBCt9j0/58u4G+rkeMqXGI= github.com/chavacava/garif v0.0.0-20210405164556-e8a0a408d6af/go.mod h1:Qjyv4H3//PWVzTeCezG2b9IRn6myJxJSr4TD/xo6ojU= +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -106,6 +118,7 @@ github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8Nz github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190620071333-e64a0ec8b42a/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= @@ -148,6 +161,8 @@ github.com/fatih/color v1.12.0 h1:mRhaKNwANqRgUBGKmnI5ZxEk7QXmjQeCcuYFMX2bfcc= github.com/fatih/color v1.12.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -155,8 +170,10 @@ github.com/fullstorydev/grpcurl v1.6.0/go.mod h1:ZQ+ayqbKMJNhzLmbpCiurTVlaK2M/3n github.com/fzipp/gocyclo v0.3.1 h1:A9UeX3HJSXTBzvHzhqoYVuE0eAhe+aM8XBCCwsPMZOc= github.com/fzipp/gocyclo v0.3.1/go.mod h1:DJHO6AUmbdqj2ET4Z9iArSuwWgYDRryYt2wASxc7x3E= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-critic/go-critic v0.5.6 h1:siUR1+322iVikWXoV75I1YRfNaC/yaLzhdF9Zwd8Tus= github.com/go-critic/go-critic v0.5.6/go.mod h1:cVjj0DfqewQVIlIAGexPCaGaZDAqGE29PYDDADIVNEo= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -169,6 +186,7 @@ github.com/go-redis/redis v6.15.8+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8w github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-toolsmith/astcast v1.0.0 h1:JojxlmI6STnFVG9yOImLeGREv8W2ocNUM+iOhR6jE7g= github.com/go-toolsmith/astcast v1.0.0/go.mod h1:mt2OdQTeAQcY4DQgPSArJjHCcOwlX+Wl/kwN+LbLGQ4= @@ -204,6 +222,7 @@ github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4er github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -211,6 +230,8 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.1.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -226,8 +247,9 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2 h1:23T5iq8rbUYlhpt5DB4XJkc6BU31uODLD1o1gKvZmD0= github.com/golangci/check v0.0.0-20180506172741-cfe4005ccda2/go.mod h1:k9Qvh+8juN+UKMCS/3jFtGICgW8O96FVaZsaxdzDkR4= github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a h1:w8hkcTqaFpzKqonE9uMCefW1WDie15eSP/4MssdenaM= @@ -262,6 +284,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -277,6 +301,8 @@ github.com/google/uuid v0.0.0-20161128191214-064e2069ce9c/go.mod h1:TIyPZe4Mgqvf github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQHCoQ= @@ -302,11 +328,13 @@ github.com/gostaticanalysis/forcetypeassert v0.0.0-20200621232751-01d4955beaa5 h github.com/gostaticanalysis/forcetypeassert v0.0.0-20200621232751-01d4955beaa5/go.mod h1:qZEedyP/sY1lTGV1uJ3VhWZ2mqag3IkWsDHVbplHXak= github.com/gostaticanalysis/nilerr v0.1.1 h1:ThE+hJP0fEp4zWLkWHWcRyI2Od0p7DlgYG3Uqrmrcpk= github.com/gostaticanalysis/nilerr v0.1.1/go.mod h1:wZYb6YI5YAxxq0i1+VJbY0s2YONW0HU0GPE3+5PWN4A= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.2.2/go.mod h1:EaizFBKfUKtMIF5iaDEhniwNedqGo9FuLFzppDr3uwI= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.12.1/go.mod h1:8XEsbTttt/W+VvjtQhLACqCisSPWTxCZ7sBRjU6iH9c= @@ -362,6 +390,7 @@ github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b h1:ZGiXF8sz7P github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b/go.mod h1:hQmNrgofl+IY/8L+n20H6E6PWBBTokdsv+q49j0QhsU= github.com/jedisct1/xsecretbox v0.0.0-20210927135450-ebe41aef7bef h1:1Jom8JnCkrgivikTdt0lg5lHpZRvpP98hn8H1bIjFLk= github.com/jedisct1/xsecretbox v0.0.0-20210927135450-ebe41aef7bef/go.mod h1:dmX1e+PPjjbMjNI/wJk8EgjXmqAMZ5tgOzD1wxCgzhs= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jgautheron/goconst v1.5.1 h1:HxVbL1MhydKs8R8n/HE5NPvzfaYmQJA3o879lE4+WcM= github.com/jgautheron/goconst v1.5.1/go.mod h1:aAosetZ5zaeC/2EfMeRswtxUFBpe2Hr7HzkgX4fanO4= github.com/jhump/protoreflect v1.6.1/go.mod h1:RZQ/lnuN+zqeRVpQigTwO6o0AJUkxbnSnpuG7toUTG4= @@ -405,6 +434,7 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= @@ -424,11 +454,25 @@ github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/logrusorgru/aurora v0.0.0-20181002194514-a7b3b318ed4e/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= +github.com/lucas-clemente/quic-go v0.28.0 h1:9eXVRgIkMQQyiyorz/dAaOYIx3TFzXsIFkNFz4cxuJM= +github.com/lucas-clemente/quic-go v0.28.0/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4= github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/maratori/testpackage v1.0.1 h1:QtJ5ZjqapShm0w5DosRjg0PRlSdAdlx+W6cCKoALdbQ= github.com/maratori/testpackage v1.0.1/go.mod h1:ddKdw+XG0Phzhx8BFDTKgpWP4i7MpApTE5fXSKAqwDU= +github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= +github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= +github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 h1:7m/WlWcSROrcK5NxuXaxYD32BZqe/LEEnBrWcH/cOqQ= +github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/matoous/godox v0.0.0-20210227103229-6504466cf951 h1:pWxk9e//NbPwfxat7RXkts09K+dEBJWakUWwICVqYbA= github.com/matoous/godox v0.0.0-20210227103229-6504466cf951/go.mod h1:1BELzlh859Sh1c6+90blK8lbYy0kwQf1bYlBhBysy1s= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= @@ -459,6 +503,7 @@ github.com/mgechev/dots v0.0.0-20190921121421-c36f7dcfbb81 h1:QASJXOGm2RZ5Ardbc8 github.com/mgechev/dots v0.0.0-20190921121421-c36f7dcfbb81/go.mod h1:KQ7+USdGKfpPjXk4Ga+5XxQM4Lm4e3gAogrreFAYpOg= github.com/mgechev/revive v1.0.7 h1:5kEWTY/W5a0Eiqnkn2BAWsRZpxbs1ft15PsyNC7Rml8= github.com/mgechev/revive v1.0.7/go.mod h1:vuE5ox/4L/HDd63MCcCk3H6wTLQ6XXezRphJ8cJJOxY= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.35/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= @@ -495,6 +540,8 @@ github.com/nakabonne/nestif v0.3.0 h1:+yOViDGhg8ygGrmII72nV9B/zGxY188TYpfolntsaP github.com/nakabonne/nestif v0.3.0/go.mod h1:dI314BppzXjJ4HsCnbo7XzrJHPszZsjnk5wEBSYHI2c= github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354 h1:4kuARK6Y6FxaNu/BnU2OAaLF86eTVhP2hjTB6iMvItA= github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354/go.mod h1:KSVJerMDfblTH7p5MZaTt+8zaT2iEk3AkVb9PQdZuE8= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/nishanths/exhaustive v0.1.0 h1:kVlMw8h2LHPMGUVqUj6230oQjjTMFjwcZrnkhXzFfl8= @@ -514,13 +561,18 @@ github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6 github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.16.1 h1:foqVmeWDD6yYpK+Yz3fHyNIxFYNxswxqNFjSKe+vI54= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.16.1/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.11.0 h1:+CqWgvj0OZycCaqclBD1pxKHAU+tOkHmQIWvDHq2aug= github.com/onsi/gomega v1.11.0/go.mod h1:azGKhqFUon9Vuj0YmTfLSmx0FUwqXYSTl5re8lQLTUg= +github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= @@ -542,6 +594,7 @@ github.com/powerman/check v1.6.0 h1:8J3QQFD2QaeMu+pDZz+EChgY62IQtRfpUSD7WHBEgg4= github.com/powerman/check v1.6.0/go.mod h1:sHb2NBEd2MdSGldglajwRsC1yxjcfngva5CGGsIMJDM= github.com/powerman/deepequal v0.1.0 h1:sVwtyTsBuYIvdbLR1O2wzRY63YgPqdGZmk/o80l+C/U= github.com/powerman/deepequal v0.1.0/go.mod h1:3k7aG/slufBhUANdN67o/UPg8i5YaiJ6FmibWX0cn04= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= @@ -552,11 +605,13 @@ github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1: github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0 h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lNawc= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= @@ -592,13 +647,34 @@ github.com/sanposhiho/wastedassign/v2 v2.0.6/go.mod h1:KyZ0MWTwxxBmfwn33zh3k1dms github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/securego/gosec/v2 v2.8.0 h1:iHg9cVmHWf5n6/ijUJ4F10h5bKlNtvXmcWzRw0lxiKE= github.com/securego/gosec/v2 v2.8.0/go.mod h1:hJZ6NT5TqoY+jmOsaxAV4cXoEdrMRLVaNPnSpUCvCZs= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c h1:W65qqJCIOVP4jpqPQ0YvHYKwcMEMVWIzWC5iNQQfBTU= github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c/go.mod h1:/PevMnwAxekIXwN8qQyfc5gl2NlkB3CQlkizAbOkeBs= github.com/shirou/gopsutil/v3 v3.21.5/go.mod h1:ghfMypLDrFSWN2c9cDYFLHyynQ+QUht0cv/18ZqVczw= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= @@ -612,8 +688,10 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/sonatard/noctx v0.0.1 h1:VC1Qhl6Oxx9vvWo3UDgrGXYCeKCe3Wbw7qAWL6FrmTY= github.com/sonatard/noctx v0.0.1/go.mod h1:9D2D/EoULe8Yy2joDHJj7bv3sZoq9AaSb8B4lqBjiZI= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/go-diff v0.6.1 h1:hmA1LzxW0n1c3Q4YbrFgg4P99GSnebYa3x8gr0HZqLQ= github.com/sourcegraph/go-diff v0.6.1/go.mod h1:iBszgVvyxdc8SFZ7gm69go2KDdt3ag071iBaWPF6cjs= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= @@ -649,6 +727,7 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tdakkota/asciicheck v0.0.0-20200416200610-e657995f937b h1:HxLVTlqcHhFAz3nWUcuvpH7WuOMv8LQoCWmruLfFH2U= github.com/tdakkota/asciicheck v0.0.0-20200416200610-e657995f937b/go.mod h1:yHp0ai0Z9gUljN3o0xMhYJnH/IcvkdTBOX2fmJ93JEM= github.com/tetafro/godot v1.4.7 h1:zBaoSY4JRVVz33y/qnODsdaKj2yAaMr91HCbqHCifVc= @@ -678,6 +757,8 @@ github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyC github.com/valyala/fasthttp v1.16.0/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA= github.com/valyala/quicktemplate v1.6.3/go.mod h1:fwPzK2fHuYEODzJ9pkw0ipCPNHZ2tD5KW4lOuSdPKzY= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= @@ -697,6 +778,7 @@ go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/etcd v0.0.0-20200513171258-e048e166ab9c/go.mod h1:xCI7ZzBfRuGgBXyXO6yfWfDmlWd35khcWpUa4L0xI/k= go.mozilla.org/mozlog v0.0.0-20170222151521-4bb13139d403/go.mod h1:jHoPAGnDrCy6kaI2tAze5Prf0Nr0w/oNkROt2lw3n3o= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -711,14 +793,19 @@ go.uber.org/multierr v1.4.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+ go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20180501155221-613d6eafa307/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= @@ -736,6 +823,7 @@ golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMk golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -762,12 +850,15 @@ golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -793,20 +884,26 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220708220712-1185a9018129 h1:vucSRfWwTsoXro7P+3Cjlr6flUMtzCwzlvkxEQtHHB0= golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -824,12 +921,14 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -862,6 +961,7 @@ golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -874,12 +974,15 @@ golang.org/x/sys v0.0.0-20210217105451-b926d437f341/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e h1:NHvCuwuS43lGnYhten69ZWqi2QOj/CiDNcKbVqwVoew= golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -896,7 +999,9 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190110163146-51295c7ec13a/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -987,6 +1092,9 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= @@ -1004,6 +1112,8 @@ google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0M google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= @@ -1012,7 +1122,11 @@ google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCID google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20170818010345-ee236bd376b0/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20181107211654-5fc9ac540362/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1043,6 +1157,9 @@ google.golang.org/genproto v0.0.0-20200626011028-ee7919e894b5/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200707001353-8e8330bf89df h1:HWF6nM8ruGdu1K8IXFR+i2oT3YP+iBfZzCbC9zUfcWo= google.golang.org/genproto v0.0.0-20200707001353-8e8330bf89df/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1069,6 +1186,7 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.0 h1:KhgSLlr/moiqjv0qUsSnLvdUL7NH7PHW8aZGn7Jpjko= google.golang.org/protobuf v1.27.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= @@ -1082,6 +1200,7 @@ gopkg.in/cheggaaa/pb.v1 v1.0.28/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qS gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= @@ -1104,6 +1223,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -1125,3 +1246,5 @@ rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8 rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/vendor/github.com/cheekybits/genny/.gitignore b/vendor/github.com/cheekybits/genny/.gitignore new file mode 100644 index 00000000..c62d148c --- /dev/null +++ b/vendor/github.com/cheekybits/genny/.gitignore @@ -0,0 +1,26 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof + +genny diff --git a/vendor/github.com/cheekybits/genny/.travis.yml b/vendor/github.com/cheekybits/genny/.travis.yml new file mode 100644 index 00000000..78ba5f2d --- /dev/null +++ b/vendor/github.com/cheekybits/genny/.travis.yml @@ -0,0 +1,6 @@ +language: go + +go: + - 1.7 + - 1.8 + - 1.9 diff --git a/vendor/github.com/cheekybits/genny/LICENSE b/vendor/github.com/cheekybits/genny/LICENSE new file mode 100644 index 00000000..519d7f22 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2014 cheekybits + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/vendor/github.com/cheekybits/genny/README.md b/vendor/github.com/cheekybits/genny/README.md new file mode 100644 index 00000000..64a28ac7 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/README.md @@ -0,0 +1,245 @@ +# genny - Generics for Go + +[![Build Status](https://travis-ci.org/cheekybits/genny.svg?branch=master)](https://travis-ci.org/cheekybits/genny) [![GoDoc](https://godoc.org/github.com/cheekybits/genny/parse?status.png)](http://godoc.org/github.com/cheekybits/genny/parse) + +Install: + +``` +go get github.com/cheekybits/genny +``` + +===== + +(pron. Jenny) by Mat Ryer ([@matryer](https://twitter.com/matryer)) and Tyler Bunnell ([@TylerJBunnell](https://twitter.com/TylerJBunnell)). + +Until the Go core team include support for [generics in Go](http://golang.org/doc/faq#generics), `genny` is a code-generation generics solution. It allows you write normal buildable and testable Go code which, when processed by the `genny gen` tool, will replace the generics with specific types. + + * Generic code is valid Go code + * Generic code compiles and can be tested + * Use `stdin` and `stdout` or specify in and out files + * Supports Go 1.4's [go generate](http://tip.golang.org/doc/go1.4#gogenerate) + * Multiple specific types will generate every permutation + * Use `BUILTINS` and `NUMBERS` wildtype to generate specific code for all built-in (and number) Go types + * Function names and comments also get updated + +## Library + +We have started building a [library of common things](https://github.com/cheekybits/gennylib), and you can use `genny get` to generate the specific versions you need. + +For example: `genny get maps/concurrentmap.go "KeyType=BUILTINS ValueType=BUILTINS"` will print out generated code for all types for a concurrent map. Any file in the library may be generated locally in this way using all the same options given to `genny gen`. + +## Usage + +``` +genny [{flags}] gen "{types}" + +gen - generates type specific code from generic code. +get - fetch a generic template from the online library and gen it. + +{flags} - (optional) Command line flags (see below) +{types} - (required) Specific types for each generic type in the source +{types} format: {generic}={specific}[,another][ {generic2}={specific2}] + +Examples: + Generic=Specific + Generic1=Specific1 Generic2=Specific2 + Generic1=Specific1,Specific2 Generic2=Specific3,Specific4 + +Flags: + -in="": file to parse instead of stdin + -out="": file to save output to instead of stdout + -pkg="": package name for generated files +``` + + * Comma separated type lists will generate code for each type + +### Flags + + * `-in` - specify the input file (rather than using stdin) + * `-out` - specify the output file (rather than using stdout) + +### go generate + +To use Go 1.4's `go generate` capability, insert the following comment in your source code file: + +``` +//go:generate genny -in=$GOFILE -out=gen-$GOFILE gen "KeyType=string,int ValueType=string,int" +``` + + * Start the line with `//go:generate ` + * Use the `-in` and `-out` flags to specify the files to work on + * Use the `genny` command as usual after the flags + +Now, running `go generate` (in a shell) for the package will cause the generic versions of the files to be generated. + + * The output file will be overwritten, so it's safe to call `go generate` many times + * Use `$GOFILE` to refer to the current file + * The `//go:generate` line will be removed from the output + +To see a real example of how to use `genny` with `go generate`, look in the [example/go-generate directory](https://github.com/cheekybits/genny/tree/master/examples/go-generate). + +## How it works + +Define your generic types using the special `generic.Type` placeholder type: + +```go +type KeyType generic.Type +type ValueType generic.Type +``` + + * You can use as many as you like + * Give them meaningful names + +Then write the generic code referencing the types as your normally would: + +```go +func SetValueTypeForKeyType(key KeyType, value ValueType) { /* ... */ } +``` + + * Generic type names will also be replaced in comments and function names (see Real example below) + +Since `generic.Type` is a real Go type, your code will compile, and you can even write unit tests against your generic code. + +#### Generating specific versions + +Pass the file through the `genny gen` tool with the specific types as the argument: + +``` +cat generic.go | genny gen "KeyType=string ValueType=interface{}" +``` + +The output will be the complete Go source file with the generic types replaced with the types specified in the arguments. + +## Real example + +Given [this generic Go code](https://github.com/cheekybits/genny/tree/master/examples/queue) which compiles and is tested: + +```go +package queue + +import "github.com/cheekybits/genny/generic" + +// NOTE: this is how easy it is to define a generic type +type Something generic.Type + +// SomethingQueue is a queue of Somethings. +type SomethingQueue struct { + items []Something +} + +func NewSomethingQueue() *SomethingQueue { + return &SomethingQueue{items: make([]Something, 0)} +} +func (q *SomethingQueue) Push(item Something) { + q.items = append(q.items, item) +} +func (q *SomethingQueue) Pop() Something { + item := q.items[0] + q.items = q.items[1:] + return item +} +``` + +When `genny gen` is invoked like this: + +``` +cat source.go | genny gen "Something=string" +``` + +It outputs: + +```go +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package queue + +// StringQueue is a queue of Strings. +type StringQueue struct { + items []string +} + +func NewStringQueue() *StringQueue { + return &StringQueue{items: make([]string, 0)} +} +func (q *StringQueue) Push(item string) { + q.items = append(q.items, item) +} +func (q *StringQueue) Pop() string { + item := q.items[0] + q.items = q.items[1:] + return item +} +``` + +To get a _something_ for every built-in Go type plus one of your own types, you could run: + +``` +cat source.go | genny gen "Something=BUILTINS,*MyType" +``` + +#### More examples + +Check out the [test code files](https://github.com/cheekybits/genny/tree/master/parse/test) for more real examples. + +## Writing test code + +Once you have defined a generic type with some code worth testing: + +```go +package slice + +import ( + "log" + "reflect" + + "github.com/stretchr/gogen/generic" +) + +type MyType generic.Type + +func EnsureMyTypeSlice(objectOrSlice interface{}) []MyType { + log.Printf("%v", reflect.TypeOf(objectOrSlice)) + switch obj := objectOrSlice.(type) { + case []MyType: + log.Println(" returning it untouched") + return obj + case MyType: + log.Println(" wrapping in slice") + return []MyType{obj} + default: + panic("ensure slice needs MyType or []MyType") + } +} +``` + +You can treat it like any normal Go type in your test code: + +```go +func TestEnsureMyTypeSlice(t *testing.T) { + + myType := new(MyType) + slice := EnsureMyTypeSlice(myType) + if assert.NotNil(t, slice) { + assert.Equal(t, slice[0], myType) + } + + slice = EnsureMyTypeSlice(slice) + log.Printf("%#v", slice[0]) + if assert.NotNil(t, slice) { + assert.Equal(t, slice[0], myType) + } + +} +``` + +### Understanding what `generic.Type` is + +Because `generic.Type` is an empty interface type (literally `interface{}`) every other type will be considered to be a `generic.Type` if you are switching on the type of an object. Of course, once the specific versions are generated, this issue goes away but it's worth knowing when you are writing your tests against generic code. + +### Contributions + + * See the [API documentation for the parse package](http://godoc.org/github.com/cheekybits/genny/parse) + * Please do TDD + * All input welcome diff --git a/vendor/github.com/cheekybits/genny/doc.go b/vendor/github.com/cheekybits/genny/doc.go new file mode 100644 index 00000000..4c31e22b --- /dev/null +++ b/vendor/github.com/cheekybits/genny/doc.go @@ -0,0 +1,2 @@ +// Package main is the command line tool for Genny. +package main diff --git a/vendor/github.com/cheekybits/genny/generic/doc.go b/vendor/github.com/cheekybits/genny/generic/doc.go new file mode 100644 index 00000000..3bd6c869 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/generic/doc.go @@ -0,0 +1,2 @@ +// Package generic contains the generic marker types. +package generic diff --git a/vendor/github.com/cheekybits/genny/generic/generic.go b/vendor/github.com/cheekybits/genny/generic/generic.go new file mode 100644 index 00000000..04a2306c --- /dev/null +++ b/vendor/github.com/cheekybits/genny/generic/generic.go @@ -0,0 +1,13 @@ +package generic + +// Type is the placeholder type that indicates a generic value. +// When genny is executed, variables of this type will be replaced with +// references to the specific types. +// var GenericType generic.Type +type Type interface{} + +// Number is the placehoder type that indiccates a generic numerical value. +// When genny is executed, variables of this type will be replaced with +// references to the specific types. +// var GenericType generic.Number +type Number float64 diff --git a/vendor/github.com/cheekybits/genny/main.go b/vendor/github.com/cheekybits/genny/main.go new file mode 100644 index 00000000..fe06a6c0 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/main.go @@ -0,0 +1,154 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "strings" + + "github.com/cheekybits/genny/out" + "github.com/cheekybits/genny/parse" +) + +/* + + source | genny gen [-in=""] [-out=""] [-pkg=""] "KeyType=string,int ValueType=string,int" + +*/ + +const ( + _ = iota + exitcodeInvalidArgs + exitcodeInvalidTypeSet + exitcodeStdinFailed + exitcodeGenFailed + exitcodeGetFailed + exitcodeSourceFileInvalid + exitcodeDestFileFailed +) + +func main() { + var ( + in = flag.String("in", "", "file to parse instead of stdin") + out = flag.String("out", "", "file to save output to instead of stdout") + pkgName = flag.String("pkg", "", "package name for generated files") + prefix = "https://github.com/metabition/gennylib/raw/master/" + ) + flag.Parse() + args := flag.Args() + + if len(args) < 2 { + usage() + os.Exit(exitcodeInvalidArgs) + } + + if strings.ToLower(args[0]) != "gen" && strings.ToLower(args[0]) != "get" { + usage() + os.Exit(exitcodeInvalidArgs) + } + + // parse the typesets + var setsArg = args[1] + if strings.ToLower(args[0]) == "get" { + setsArg = args[2] + } + typeSets, err := parse.TypeSet(setsArg) + if err != nil { + fatal(exitcodeInvalidTypeSet, err) + } + + outWriter := newWriter(*out) + + if strings.ToLower(args[0]) == "get" { + if len(args) != 3 { + fmt.Println("not enough arguments to get") + usage() + os.Exit(exitcodeInvalidArgs) + } + r, err := http.Get(prefix + args[1]) + if err != nil { + fatal(exitcodeGetFailed, err) + } + b, err := ioutil.ReadAll(r.Body) + if err != nil { + fatal(exitcodeGetFailed, err) + } + r.Body.Close() + br := bytes.NewReader(b) + err = gen(*in, *pkgName, br, typeSets, outWriter) + } else if len(*in) > 0 { + var file *os.File + file, err = os.Open(*in) + if err != nil { + fatal(exitcodeSourceFileInvalid, err) + } + defer file.Close() + err = gen(*in, *pkgName, file, typeSets, outWriter) + } else { + var source []byte + source, err = ioutil.ReadAll(os.Stdin) + if err != nil { + fatal(exitcodeStdinFailed, err) + } + reader := bytes.NewReader(source) + err = gen("stdin", *pkgName, reader, typeSets, outWriter) + } + + // do the work + if err != nil { + fatal(exitcodeGenFailed, err) + } + +} + +func usage() { + fmt.Fprintln(os.Stderr, `usage: genny [{flags}] gen "{types}" + +gen - generates type specific code from generic code. +get - fetch a generic template from the online library and gen it. + +{flags} - (optional) Command line flags (see below) +{types} - (required) Specific types for each generic type in the source +{types} format: {generic}={specific}[,another][ {generic2}={specific2}] + +Examples: + Generic=Specific + Generic1=Specific1 Generic2=Specific2 + Generic1=Specific1,Specific2 Generic2=Specific3,Specific4 + +Flags:`) + flag.PrintDefaults() +} + +func newWriter(fileName string) io.Writer { + if fileName == "" { + return os.Stdout + } + lf := &out.LazyFile{FileName: fileName} + defer lf.Close() + return lf +} + +func fatal(code int, a ...interface{}) { + fmt.Println(a...) + os.Exit(code) +} + +// gen performs the generic generation. +func gen(filename, pkgName string, in io.ReadSeeker, typesets []map[string]string, out io.Writer) error { + + var output []byte + var err error + + output, err = parse.Generics(filename, pkgName, in, typesets) + if err != nil { + return err + } + + out.Write(output) + return nil +} diff --git a/vendor/github.com/cheekybits/genny/out/lazy_file.go b/vendor/github.com/cheekybits/genny/out/lazy_file.go new file mode 100644 index 00000000..7c8815f5 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/out/lazy_file.go @@ -0,0 +1,38 @@ +package out + +import ( + "os" + "path" +) + +// LazyFile is an io.WriteCloser which defers creation of the file it is supposed to write in +// till the first call to its write function in order to prevent creation of file, if no write +// is supposed to happen. +type LazyFile struct { + // FileName is path to the file to which genny will write. + FileName string + file *os.File +} + +// Close closes the file if it is created. Returns nil if no file is created. +func (lw *LazyFile) Close() error { + if lw.file != nil { + return lw.file.Close() + } + return nil +} + +// Write writes to the specified file and creates the file first time it is called. +func (lw *LazyFile) Write(p []byte) (int, error) { + if lw.file == nil { + err := os.MkdirAll(path.Dir(lw.FileName), 0755) + if err != nil { + return 0, err + } + lw.file, err = os.Create(lw.FileName) + if err != nil { + return 0, err + } + } + return lw.file.Write(p) +} diff --git a/vendor/github.com/cheekybits/genny/parse/builtins.go b/vendor/github.com/cheekybits/genny/parse/builtins.go new file mode 100644 index 00000000..e0299544 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/parse/builtins.go @@ -0,0 +1,41 @@ +package parse + +// Builtins contains a slice of all built-in Go types. +var Builtins = []string{ + "bool", + "byte", + "complex128", + "complex64", + "error", + "float32", + "float64", + "int", + "int16", + "int32", + "int64", + "int8", + "rune", + "string", + "uint", + "uint16", + "uint32", + "uint64", + "uint8", + "uintptr", +} + +// Numbers contains a slice of all built-in number types. +var Numbers = []string{ + "float32", + "float64", + "int", + "int16", + "int32", + "int64", + "int8", + "uint", + "uint16", + "uint32", + "uint64", + "uint8", +} diff --git a/vendor/github.com/cheekybits/genny/parse/doc.go b/vendor/github.com/cheekybits/genny/parse/doc.go new file mode 100644 index 00000000..1be4fed8 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/parse/doc.go @@ -0,0 +1,14 @@ +// Package parse contains the generic code generation capabilities +// that power genny. +// +// genny gen "{types}" +// +// gen - generates type specific code (to stdout) from generic code (via stdin) +// +// {types} - (required) Specific types for each generic type in the source +// {types} format: {generic}={specific}[,another][ {generic2}={specific2}] +// Examples: +// Generic=Specific +// Generic1=Specific1 Generic2=Specific2 +// Generic1=Specific1,Specific2 Generic2=Specific3,Specific4 +package parse diff --git a/vendor/github.com/cheekybits/genny/parse/errors.go b/vendor/github.com/cheekybits/genny/parse/errors.go new file mode 100644 index 00000000..ab812bf9 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/parse/errors.go @@ -0,0 +1,47 @@ +package parse + +import ( + "errors" +) + +// errMissingSpecificType represents an error when a generic type is not +// satisfied by a specific type. +type errMissingSpecificType struct { + GenericType string +} + +// Error gets a human readable string describing this error. +func (e errMissingSpecificType) Error() string { + return "Missing specific type for '" + e.GenericType + "' generic type" +} + +// errImports represents an error from goimports. +type errImports struct { + Err error +} + +// Error gets a human readable string describing this error. +func (e errImports) Error() string { + return "Failed to goimports the generated code: " + e.Err.Error() +} + +// errSource represents an error with the source file. +type errSource struct { + Err error +} + +// Error gets a human readable string describing this error. +func (e errSource) Error() string { + return "Failed to parse source file: " + e.Err.Error() +} + +type errBadTypeArgs struct { + Message string + Arg string +} + +func (e errBadTypeArgs) Error() string { + return "\"" + e.Arg + "\" is bad: " + e.Message +} + +var errMissingTypeInformation = errors.New("No type arguments were specified and no \"// +gogen\" tag was found in the source.") diff --git a/vendor/github.com/cheekybits/genny/parse/parse.go b/vendor/github.com/cheekybits/genny/parse/parse.go new file mode 100644 index 00000000..08eb48b1 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/parse/parse.go @@ -0,0 +1,298 @@ +package parse + +import ( + "bufio" + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/scanner" + "go/token" + "io" + "os" + "strings" + "unicode" + + "golang.org/x/tools/imports" +) + +var header = []byte(` + +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +`) + +var ( + packageKeyword = []byte("package") + importKeyword = []byte("import") + openBrace = []byte("(") + closeBrace = []byte(")") + genericPackage = "generic" + genericType = "generic.Type" + genericNumber = "generic.Number" + linefeed = "\r\n" +) +var unwantedLinePrefixes = [][]byte{ + []byte("//go:generate genny "), +} + +func subIntoLiteral(lit, typeTemplate, specificType string) string { + if lit == typeTemplate { + return specificType + } + if !strings.Contains(lit, typeTemplate) { + return lit + } + specificLg := wordify(specificType, true) + specificSm := wordify(specificType, false) + result := strings.Replace(lit, typeTemplate, specificLg, -1) + if strings.HasPrefix(result, specificLg) && !isExported(lit) { + return strings.Replace(result, specificLg, specificSm, 1) + } + return result +} + +func subTypeIntoComment(line, typeTemplate, specificType string) string { + var subbed string + for _, w := range strings.Fields(line) { + subbed = subbed + subIntoLiteral(w, typeTemplate, specificType) + " " + } + return subbed +} + +// Does the heavy lifting of taking a line of our code and +// sbustituting a type into there for our generic type +func subTypeIntoLine(line, typeTemplate, specificType string) string { + src := []byte(line) + var s scanner.Scanner + fset := token.NewFileSet() + file := fset.AddFile("", fset.Base(), len(src)) + s.Init(file, src, nil, scanner.ScanComments) + output := "" + for { + _, tok, lit := s.Scan() + if tok == token.EOF { + break + } else if tok == token.COMMENT { + subbed := subTypeIntoComment(lit, typeTemplate, specificType) + output = output + subbed + " " + } else if tok.IsLiteral() { + subbed := subIntoLiteral(lit, typeTemplate, specificType) + output = output + subbed + " " + } else { + output = output + tok.String() + " " + } + } + return output +} + +// typeSet looks like "KeyType: int, ValueType: string" +func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]string) ([]byte, error) { + + // ensure we are at the beginning of the file + in.Seek(0, os.SEEK_SET) + + // parse the source file + fs := token.NewFileSet() + file, err := parser.ParseFile(fs, filename, in, 0) + if err != nil { + return nil, &errSource{Err: err} + } + + // make sure every generic.Type is represented in the types + // argument. + for _, decl := range file.Decls { + switch it := decl.(type) { + case *ast.GenDecl: + for _, spec := range it.Specs { + ts, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + switch tt := ts.Type.(type) { + case *ast.SelectorExpr: + if name, ok := tt.X.(*ast.Ident); ok { + if name.Name == genericPackage { + if _, ok := typeSet[ts.Name.Name]; !ok { + return nil, &errMissingSpecificType{GenericType: ts.Name.Name} + } + } + } + } + } + } + } + + in.Seek(0, os.SEEK_SET) + + var buf bytes.Buffer + + comment := "" + scanner := bufio.NewScanner(in) + for scanner.Scan() { + + line := scanner.Text() + + // does this line contain generic.Type? + if strings.Contains(line, genericType) || strings.Contains(line, genericNumber) { + comment = "" + continue + } + + for t, specificType := range typeSet { + if strings.Contains(line, t) { + newLine := subTypeIntoLine(line, t, specificType) + line = newLine + } + } + + if comment != "" { + buf.WriteString(makeLine(comment)) + comment = "" + } + + // is this line a comment? + // TODO: should we handle /* */ comments? + if strings.HasPrefix(line, "//") { + // record this line to print later + comment = line + continue + } + + // write the line + buf.WriteString(makeLine(line)) + } + + // write it out + return buf.Bytes(), nil +} + +// Generics parses the source file and generates the bytes replacing the +// generic types for the keys map with the specific types (its value). +func Generics(filename, pkgName string, in io.ReadSeeker, typeSets []map[string]string) ([]byte, error) { + + totalOutput := header + + for _, typeSet := range typeSets { + + // generate the specifics + parsed, err := generateSpecific(filename, in, typeSet) + if err != nil { + return nil, err + } + + totalOutput = append(totalOutput, parsed...) + + } + + // clean up the code line by line + packageFound := false + insideImportBlock := false + var cleanOutputLines []string + scanner := bufio.NewScanner(bytes.NewReader(totalOutput)) + for scanner.Scan() { + + // end of imports block? + if insideImportBlock { + if bytes.HasSuffix(scanner.Bytes(), closeBrace) { + insideImportBlock = false + } + continue + } + + if bytes.HasPrefix(scanner.Bytes(), packageKeyword) { + if packageFound { + continue + } else { + packageFound = true + } + } else if bytes.HasPrefix(scanner.Bytes(), importKeyword) { + if bytes.HasSuffix(scanner.Bytes(), openBrace) { + insideImportBlock = true + } + continue + } + + // check all unwantedLinePrefixes - and skip them + skipline := false + for _, prefix := range unwantedLinePrefixes { + if bytes.HasPrefix(scanner.Bytes(), prefix) { + skipline = true + continue + } + } + + if skipline { + continue + } + + cleanOutputLines = append(cleanOutputLines, makeLine(scanner.Text())) + } + + cleanOutput := strings.Join(cleanOutputLines, "") + + output := []byte(cleanOutput) + var err error + + // change package name + if pkgName != "" { + output = changePackage(bytes.NewReader([]byte(output)), pkgName) + } + // fix the imports + output, err = imports.Process(filename, output, nil) + if err != nil { + return nil, &errImports{Err: err} + } + + return output, nil +} + +func makeLine(s string) string { + return fmt.Sprintln(strings.TrimRight(s, linefeed)) +} + +// isAlphaNumeric gets whether the rune is alphanumeric or _. +func isAlphaNumeric(r rune) bool { + return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) +} + +// wordify turns a type into a nice word for function and type +// names etc. +func wordify(s string, exported bool) string { + s = strings.TrimRight(s, "{}") + s = strings.TrimLeft(s, "*&") + s = strings.Replace(s, ".", "", -1) + if !exported { + return s + } + return strings.ToUpper(string(s[0])) + s[1:] +} + +func changePackage(r io.Reader, pkgName string) []byte { + var out bytes.Buffer + sc := bufio.NewScanner(r) + done := false + + for sc.Scan() { + s := sc.Text() + + if !done && strings.HasPrefix(s, "package") { + parts := strings.Split(s, " ") + parts[1] = pkgName + s = strings.Join(parts, " ") + done = true + } + + fmt.Fprintln(&out, s) + } + return out.Bytes() +} + +func isExported(lit string) bool { + if len(lit) == 0 { + return false + } + return unicode.IsUpper(rune(lit[0])) +} diff --git a/vendor/github.com/cheekybits/genny/parse/typesets.go b/vendor/github.com/cheekybits/genny/parse/typesets.go new file mode 100644 index 00000000..c30b9728 --- /dev/null +++ b/vendor/github.com/cheekybits/genny/parse/typesets.go @@ -0,0 +1,89 @@ +package parse + +import "strings" + +const ( + typeSep = " " + keyValueSep = "=" + valuesSep = "," + builtins = "BUILTINS" + numbers = "NUMBERS" +) + +// TypeSet turns a type string into a []map[string]string +// that can be given to parse.Generics for it to do its magic. +// +// Acceptable args are: +// +// Person=man +// Person=man Animal=dog +// Person=man Animal=dog Animal2=cat +// Person=man,woman Animal=dog,cat +// Person=man,woman,child Animal=dog,cat Place=london,paris +func TypeSet(arg string) ([]map[string]string, error) { + + types := make(map[string][]string) + var keys []string + for _, pair := range strings.Split(arg, typeSep) { + segs := strings.Split(pair, keyValueSep) + if len(segs) != 2 { + return nil, &errBadTypeArgs{Arg: arg, Message: "Generic=Specific expected"} + } + key := segs[0] + keys = append(keys, key) + types[key] = make([]string, 0) + for _, t := range strings.Split(segs[1], valuesSep) { + if t == builtins { + types[key] = append(types[key], Builtins...) + } else if t == numbers { + types[key] = append(types[key], Numbers...) + } else { + types[key] = append(types[key], t) + } + } + } + + cursors := make(map[string]int) + for _, key := range keys { + cursors[key] = 0 + } + + outChan := make(chan map[string]string) + go func() { + buildTypeSet(keys, 0, cursors, types, outChan) + close(outChan) + }() + + var typeSets []map[string]string + for typeSet := range outChan { + typeSets = append(typeSets, typeSet) + } + + return typeSets, nil + +} + +func buildTypeSet(keys []string, keyI int, cursors map[string]int, types map[string][]string, out chan<- map[string]string) { + key := keys[keyI] + for cursors[key] < len(types[key]) { + if keyI < len(keys)-1 { + buildTypeSet(keys, keyI+1, copycursors(cursors), types, out) + } else { + // build the typeset for this combination + ts := make(map[string]string) + for k, vals := range types { + ts[k] = vals[cursors[k]] + } + out <- ts + } + cursors[key]++ + } +} + +func copycursors(source map[string]int) map[string]int { + copy := make(map[string]int) + for k, v := range source { + copy[k] = v + } + return copy +} diff --git a/vendor/github.com/go-task/slim-sprig/.editorconfig b/vendor/github.com/go-task/slim-sprig/.editorconfig new file mode 100644 index 00000000..b0c95367 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/.editorconfig @@ -0,0 +1,14 @@ +# editorconfig.org + +root = true + +[*] +insert_final_newline = true +charset = utf-8 +trim_trailing_whitespace = true +indent_style = tab +indent_size = 8 + +[*.{md,yml,yaml,json}] +indent_style = space +indent_size = 2 diff --git a/vendor/github.com/go-task/slim-sprig/.gitattributes b/vendor/github.com/go-task/slim-sprig/.gitattributes new file mode 100644 index 00000000..176a458f --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/.gitattributes @@ -0,0 +1 @@ +* text=auto diff --git a/vendor/github.com/go-task/slim-sprig/.gitignore b/vendor/github.com/go-task/slim-sprig/.gitignore new file mode 100644 index 00000000..5e3002f8 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/.gitignore @@ -0,0 +1,2 @@ +vendor/ +/.glide diff --git a/vendor/github.com/go-task/slim-sprig/CHANGELOG.md b/vendor/github.com/go-task/slim-sprig/CHANGELOG.md new file mode 100644 index 00000000..61d8ebff --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/CHANGELOG.md @@ -0,0 +1,364 @@ +# Changelog + +## Release 3.2.0 (2020-12-14) + +### Added + +- #211: Added randInt function (thanks @kochurovro) +- #223: Added fromJson and mustFromJson functions (thanks @mholt) +- #242: Added a bcrypt function (thanks @robbiet480) +- #253: Added randBytes function (thanks @MikaelSmith) +- #254: Added dig function for dicts (thanks @nyarly) +- #257: Added regexQuoteMeta for quoting regex metadata (thanks @rheaton) +- #261: Added filepath functions osBase, osDir, osExt, osClean, osIsAbs (thanks @zugl) +- #268: Added and and all functions for testing conditions (thanks @phuslu) +- #181: Added float64 arithmetic addf, add1f, subf, divf, mulf, maxf, and minf + (thanks @andrewmostello) +- #265: Added chunk function to split array into smaller arrays (thanks @karelbilek) +- #270: Extend certificate functions to handle non-RSA keys + add support for + ed25519 keys (thanks @misberner) + +### Changed + +- Removed testing and support for Go 1.12. ed25519 support requires Go 1.13 or newer +- Using semver 3.1.1 and mergo 0.3.11 + +### Fixed + +- #249: Fix htmlDateInZone example (thanks @spawnia) + +NOTE: The dependency github.com/imdario/mergo reverted the breaking change in +0.3.9 via 0.3.10 release. + +## Release 3.1.0 (2020-04-16) + +NOTE: The dependency github.com/imdario/mergo made a behavior change in 0.3.9 +that impacts sprig functionality. Do not use sprig with a version newer than 0.3.8. + +### Added + +- #225: Added support for generating htpasswd hash (thanks @rustycl0ck) +- #224: Added duration filter (thanks @frebib) +- #205: Added `seq` function (thanks @thadc23) + +### Changed + +- #203: Unlambda functions with correct signature (thanks @muesli) +- #236: Updated the license formatting for GitHub display purposes +- #238: Updated package dependency versions. Note, mergo not updated to 0.3.9 + as it causes a breaking change for sprig. That issue is tracked at + https://github.com/imdario/mergo/issues/139 + +### Fixed + +- #229: Fix `seq` example in docs (thanks @kalmant) + +## Release 3.0.2 (2019-12-13) + +### Fixed + +- #220: Updating to semver v3.0.3 to fix issue with <= ranges +- #218: fix typo elyptical->elliptic in ecdsa key description (thanks @laverya) + +## Release 3.0.1 (2019-12-08) + +### Fixed + +- #212: Updated semver fixing broken constraint checking with ^0.0 + +## Release 3.0.0 (2019-10-02) + +### Added + +- #187: Added durationRound function (thanks @yjp20) +- #189: Added numerous template functions that return errors rather than panic (thanks @nrvnrvn) +- #193: Added toRawJson support (thanks @Dean-Coakley) +- #197: Added get support to dicts (thanks @Dean-Coakley) + +### Changed + +- #186: Moving dependency management to Go modules +- #186: Updated semver to v3. This has changes in the way ^ is handled +- #194: Updated documentation on merging and how it copies. Added example using deepCopy +- #196: trunc now supports negative values (thanks @Dean-Coakley) + +## Release 2.22.0 (2019-10-02) + +### Added + +- #173: Added getHostByName function to resolve dns names to ips (thanks @fcgravalos) +- #195: Added deepCopy function for use with dicts + +### Changed + +- Updated merge and mergeOverwrite documentation to explain copying and how to + use deepCopy with it + +## Release 2.21.0 (2019-09-18) + +### Added + +- #122: Added encryptAES/decryptAES functions (thanks @n0madic) +- #128: Added toDecimal support (thanks @Dean-Coakley) +- #169: Added list contcat (thanks @astorath) +- #174: Added deepEqual function (thanks @bonifaido) +- #170: Added url parse and join functions (thanks @astorath) + +### Changed + +- #171: Updated glide config for Google UUID to v1 and to add ranges to semver and testify + +### Fixed + +- #172: Fix semver wildcard example (thanks @piepmatz) +- #175: Fix dateInZone doc example (thanks @s3than) + +## Release 2.20.0 (2019-06-18) + +### Added + +- #164: Adding function to get unix epoch for a time (@mattfarina) +- #166: Adding tests for date_in_zone (@mattfarina) + +### Changed + +- #144: Fix function comments based on best practices from Effective Go (@CodeLingoTeam) +- #150: Handles pointer type for time.Time in "htmlDate" (@mapreal19) +- #161, #157, #160, #153, #158, #156, #155, #159, #152 documentation updates (@badeadan) + +### Fixed + +## Release 2.19.0 (2019-03-02) + +IMPORTANT: This release reverts a change from 2.18.0 + +In the previous release (2.18), we prematurely merged a partial change to the crypto functions that led to creating two sets of crypto functions (I blame @technosophos -- since that's me). This release rolls back that change, and does what was originally intended: It alters the existing crypto functions to use secure random. + +We debated whether this classifies as a change worthy of major revision, but given the proximity to the last release, we have decided that treating 2.18 as a faulty release is the correct course of action. We apologize for any inconvenience. + +### Changed + +- Fix substr panic 35fb796 (Alexey igrychev) +- Remove extra period 1eb7729 (Matthew Lorimor) +- Make random string functions use crypto by default 6ceff26 (Matthew Lorimor) +- README edits/fixes/suggestions 08fe136 (Lauri Apple) + + +## Release 2.18.0 (2019-02-12) + +### Added + +- Added mergeOverwrite function +- cryptographic functions that use secure random (see fe1de12) + +### Changed + +- Improve documentation of regexMatch function, resolves #139 90b89ce (Jan Tagscherer) +- Handle has for nil list 9c10885 (Daniel Cohen) +- Document behaviour of mergeOverwrite fe0dbe9 (Lukas Rieder) +- doc: adds missing documentation. 4b871e6 (Fernandez Ludovic) +- Replace outdated goutils imports 01893d2 (Matthew Lorimor) +- Surface crypto secure random strings from goutils fe1de12 (Matthew Lorimor) +- Handle untyped nil values as paramters to string functions 2b2ec8f (Morten Torkildsen) + +### Fixed + +- Fix dict merge issue and provide mergeOverwrite .dst .src1 to overwrite from src -> dst 4c59c12 (Lukas Rieder) +- Fix substr var names and comments d581f80 (Dean Coakley) +- Fix substr documentation 2737203 (Dean Coakley) + +## Release 2.17.1 (2019-01-03) + +### Fixed + +The 2.17.0 release did not have a version pinned for xstrings, which caused compilation failures when xstrings < 1.2 was used. This adds the correct version string to glide.yaml. + +## Release 2.17.0 (2019-01-03) + +### Added + +- adds alder32sum function and test 6908fc2 (marshallford) +- Added kebabcase function ca331a1 (Ilyes512) + +### Changed + +- Update goutils to 1.1.0 4e1125d (Matt Butcher) + +### Fixed + +- Fix 'has' documentation e3f2a85 (dean-coakley) +- docs(dict): fix typo in pick example dc424f9 (Dustin Specker) +- fixes spelling errors... not sure how that happened 4cf188a (marshallford) + +## Release 2.16.0 (2018-08-13) + +### Added + +- add splitn function fccb0b0 (Helgi Þorbjörnsson) +- Add slice func df28ca7 (gongdo) +- Generate serial number a3bdffd (Cody Coons) +- Extract values of dict with values function df39312 (Lawrence Jones) + +### Changed + +- Modify panic message for list.slice ae38335 (gongdo) +- Minor improvement in code quality - Removed an unreachable piece of code at defaults.go#L26:6 - Resolve formatting issues. 5834241 (Abhishek Kashyap) +- Remove duplicated documentation 1d97af1 (Matthew Fisher) +- Test on go 1.11 49df809 (Helgi Þormar Þorbjörnsson) + +### Fixed + +- Fix file permissions c5f40b5 (gongdo) +- Fix example for buildCustomCert 7779e0d (Tin Lam) + +## Release 2.15.0 (2018-04-02) + +### Added + +- #68 and #69: Add json helpers to docs (thanks @arunvelsriram) +- #66: Add ternary function (thanks @binoculars) +- #67: Allow keys function to take multiple dicts (thanks @binoculars) +- #89: Added sha1sum to crypto function (thanks @benkeil) +- #81: Allow customizing Root CA that used by genSignedCert (thanks @chenzhiwei) +- #92: Add travis testing for go 1.10 +- #93: Adding appveyor config for windows testing + +### Changed + +- #90: Updating to more recent dependencies +- #73: replace satori/go.uuid with google/uuid (thanks @petterw) + +### Fixed + +- #76: Fixed documentation typos (thanks @Thiht) +- Fixed rounding issue on the `ago` function. Note, the removes support for Go 1.8 and older + +## Release 2.14.1 (2017-12-01) + +### Fixed + +- #60: Fix typo in function name documentation (thanks @neil-ca-moore) +- #61: Removing line with {{ due to blocking github pages genertion +- #64: Update the list functions to handle int, string, and other slices for compatibility + +## Release 2.14.0 (2017-10-06) + +This new version of Sprig adds a set of functions for generating and working with SSL certificates. + +- `genCA` generates an SSL Certificate Authority +- `genSelfSignedCert` generates an SSL self-signed certificate +- `genSignedCert` generates an SSL certificate and key based on a given CA + +## Release 2.13.0 (2017-09-18) + +This release adds new functions, including: + +- `regexMatch`, `regexFindAll`, `regexFind`, `regexReplaceAll`, `regexReplaceAllLiteral`, and `regexSplit` to work with regular expressions +- `floor`, `ceil`, and `round` math functions +- `toDate` converts a string to a date +- `nindent` is just like `indent` but also prepends a new line +- `ago` returns the time from `time.Now` + +### Added + +- #40: Added basic regex functionality (thanks @alanquillin) +- #41: Added ceil floor and round functions (thanks @alanquillin) +- #48: Added toDate function (thanks @andreynering) +- #50: Added nindent function (thanks @binoculars) +- #46: Added ago function (thanks @slayer) + +### Changed + +- #51: Updated godocs to include new string functions (thanks @curtisallen) +- #49: Added ability to merge multiple dicts (thanks @binoculars) + +## Release 2.12.0 (2017-05-17) + +- `snakecase`, `camelcase`, and `shuffle` are three new string functions +- `fail` allows you to bail out of a template render when conditions are not met + +## Release 2.11.0 (2017-05-02) + +- Added `toJson` and `toPrettyJson` +- Added `merge` +- Refactored documentation + +## Release 2.10.0 (2017-03-15) + +- Added `semver` and `semverCompare` for Semantic Versions +- `list` replaces `tuple` +- Fixed issue with `join` +- Added `first`, `last`, `intial`, `rest`, `prepend`, `append`, `toString`, `toStrings`, `sortAlpha`, `reverse`, `coalesce`, `pluck`, `pick`, `compact`, `keys`, `omit`, `uniq`, `has`, `without` + +## Release 2.9.0 (2017-02-23) + +- Added `splitList` to split a list +- Added crypto functions of `genPrivateKey` and `derivePassword` + +## Release 2.8.0 (2016-12-21) + +- Added access to several path functions (`base`, `dir`, `clean`, `ext`, and `abs`) +- Added functions for _mutating_ dictionaries (`set`, `unset`, `hasKey`) + +## Release 2.7.0 (2016-12-01) + +- Added `sha256sum` to generate a hash of an input +- Added functions to convert a numeric or string to `int`, `int64`, `float64` + +## Release 2.6.0 (2016-10-03) + +- Added a `uuidv4` template function for generating UUIDs inside of a template. + +## Release 2.5.0 (2016-08-19) + +- New `trimSuffix`, `trimPrefix`, `hasSuffix`, and `hasPrefix` functions +- New aliases have been added for a few functions that didn't follow the naming conventions (`trimAll` and `abbrevBoth`) +- `trimall` and `abbrevboth` (notice the case) are deprecated and will be removed in 3.0.0 + +## Release 2.4.0 (2016-08-16) + +- Adds two functions: `until` and `untilStep` + +## Release 2.3.0 (2016-06-21) + +- cat: Concatenate strings with whitespace separators. +- replace: Replace parts of a string: `replace " " "-" "Me First"` renders "Me-First" +- plural: Format plurals: `len "foo" | plural "one foo" "many foos"` renders "many foos" +- indent: Indent blocks of text in a way that is sensitive to "\n" characters. + +## Release 2.2.0 (2016-04-21) + +- Added a `genPrivateKey` function (Thanks @bacongobbler) + +## Release 2.1.0 (2016-03-30) + +- `default` now prints the default value when it does not receive a value down the pipeline. It is much safer now to do `{{.Foo | default "bar"}}`. +- Added accessors for "hermetic" functions. These return only functions that, when given the same input, produce the same output. + +## Release 2.0.0 (2016-03-29) + +Because we switched from `int` to `int64` as the return value for all integer math functions, the library's major version number has been incremented. + +- `min` complements `max` (formerly `biggest`) +- `empty` indicates that a value is the empty value for its type +- `tuple` creates a tuple inside of a template: `{{$t := tuple "a", "b" "c"}}` +- `dict` creates a dictionary inside of a template `{{$d := dict "key1" "val1" "key2" "val2"}}` +- Date formatters have been added for HTML dates (as used in `date` input fields) +- Integer math functions can convert from a number of types, including `string` (via `strconv.ParseInt`). + +## Release 1.2.0 (2016-02-01) + +- Added quote and squote +- Added b32enc and b32dec +- add now takes varargs +- biggest now takes varargs + +## Release 1.1.0 (2015-12-29) + +- Added #4: Added contains function. strings.Contains, but with the arguments + switched to simplify common pipelines. (thanks krancour) +- Added Travis-CI testing support + +## Release 1.0.0 (2015-12-23) + +- Initial release diff --git a/vendor/github.com/go-task/slim-sprig/LICENSE.txt b/vendor/github.com/go-task/slim-sprig/LICENSE.txt new file mode 100644 index 00000000..f311b1ea --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/LICENSE.txt @@ -0,0 +1,19 @@ +Copyright (C) 2013-2020 Masterminds + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/vendor/github.com/go-task/slim-sprig/README.md b/vendor/github.com/go-task/slim-sprig/README.md new file mode 100644 index 00000000..72579471 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/README.md @@ -0,0 +1,73 @@ +# Slim-Sprig: Template functions for Go templates [![GoDoc](https://godoc.org/github.com/go-task/slim-sprig?status.svg)](https://godoc.org/github.com/go-task/slim-sprig) [![Go Report Card](https://goreportcard.com/badge/github.com/go-task/slim-sprig)](https://goreportcard.com/report/github.com/go-task/slim-sprig) + +Slim-Sprig is a fork of [Sprig](https://github.com/Masterminds/sprig), but with +all functions that depend on external (non standard library) or crypto packages +removed. +The reason for this is to make this library more lightweight. Most of these +functions (specially crypto ones) are not needed on most apps, but costs a lot +in terms of binary size and compilation time. + +## Usage + +**Template developers**: Please use Slim-Sprig's [function documentation](https://go-task.github.io/slim-sprig/) for +detailed instructions and code snippets for the >100 template functions available. + +**Go developers**: If you'd like to include Slim-Sprig as a library in your program, +our API documentation is available [at GoDoc.org](http://godoc.org/github.com/go-task/slim-sprig). + +For standard usage, read on. + +### Load the Slim-Sprig library + +To load the Slim-Sprig `FuncMap`: + +```go + +import ( + "html/template" + + "github.com/go-task/slim-sprig" +) + +// This example illustrates that the FuncMap *must* be set before the +// templates themselves are loaded. +tpl := template.Must( + template.New("base").Funcs(sprig.FuncMap()).ParseGlob("*.html") +) +``` + +### Calling the functions inside of templates + +By convention, all functions are lowercase. This seems to follow the Go +idiom for template functions (as opposed to template methods, which are +TitleCase). For example, this: + +``` +{{ "hello!" | upper | repeat 5 }} +``` + +produces this: + +``` +HELLO!HELLO!HELLO!HELLO!HELLO! +``` + +## Principles Driving Our Function Selection + +We followed these principles to decide which functions to add and how to implement them: + +- Use template functions to build layout. The following + types of operations are within the domain of template functions: + - Formatting + - Layout + - Simple type conversions + - Utilities that assist in handling common formatting and layout needs (e.g. arithmetic) +- Template functions should not return errors unless there is no way to print + a sensible value. For example, converting a string to an integer should not + produce an error if conversion fails. Instead, it should display a default + value. +- Simple math is necessary for grid layouts, pagers, and so on. Complex math + (anything other than arithmetic) should be done outside of templates. +- Template functions only deal with the data passed into them. They never retrieve + data from a source. +- Finally, do not override core Go template functions. diff --git a/vendor/github.com/go-task/slim-sprig/Taskfile.yml b/vendor/github.com/go-task/slim-sprig/Taskfile.yml new file mode 100644 index 00000000..cdcfd223 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/Taskfile.yml @@ -0,0 +1,12 @@ +# https://taskfile.dev + +version: '2' + +tasks: + default: + cmds: + - task: test + + test: + cmds: + - go test -v . diff --git a/vendor/github.com/go-task/slim-sprig/crypto.go b/vendor/github.com/go-task/slim-sprig/crypto.go new file mode 100644 index 00000000..d06e516d --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/crypto.go @@ -0,0 +1,24 @@ +package sprig + +import ( + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "fmt" + "hash/adler32" +) + +func sha256sum(input string) string { + hash := sha256.Sum256([]byte(input)) + return hex.EncodeToString(hash[:]) +} + +func sha1sum(input string) string { + hash := sha1.Sum([]byte(input)) + return hex.EncodeToString(hash[:]) +} + +func adler32sum(input string) string { + hash := adler32.Checksum([]byte(input)) + return fmt.Sprintf("%d", hash) +} diff --git a/vendor/github.com/go-task/slim-sprig/date.go b/vendor/github.com/go-task/slim-sprig/date.go new file mode 100644 index 00000000..ed022dda --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/date.go @@ -0,0 +1,152 @@ +package sprig + +import ( + "strconv" + "time" +) + +// Given a format and a date, format the date string. +// +// Date can be a `time.Time` or an `int, int32, int64`. +// In the later case, it is treated as seconds since UNIX +// epoch. +func date(fmt string, date interface{}) string { + return dateInZone(fmt, date, "Local") +} + +func htmlDate(date interface{}) string { + return dateInZone("2006-01-02", date, "Local") +} + +func htmlDateInZone(date interface{}, zone string) string { + return dateInZone("2006-01-02", date, zone) +} + +func dateInZone(fmt string, date interface{}, zone string) string { + var t time.Time + switch date := date.(type) { + default: + t = time.Now() + case time.Time: + t = date + case *time.Time: + t = *date + case int64: + t = time.Unix(date, 0) + case int: + t = time.Unix(int64(date), 0) + case int32: + t = time.Unix(int64(date), 0) + } + + loc, err := time.LoadLocation(zone) + if err != nil { + loc, _ = time.LoadLocation("UTC") + } + + return t.In(loc).Format(fmt) +} + +func dateModify(fmt string, date time.Time) time.Time { + d, err := time.ParseDuration(fmt) + if err != nil { + return date + } + return date.Add(d) +} + +func mustDateModify(fmt string, date time.Time) (time.Time, error) { + d, err := time.ParseDuration(fmt) + if err != nil { + return time.Time{}, err + } + return date.Add(d), nil +} + +func dateAgo(date interface{}) string { + var t time.Time + + switch date := date.(type) { + default: + t = time.Now() + case time.Time: + t = date + case int64: + t = time.Unix(date, 0) + case int: + t = time.Unix(int64(date), 0) + } + // Drop resolution to seconds + duration := time.Since(t).Round(time.Second) + return duration.String() +} + +func duration(sec interface{}) string { + var n int64 + switch value := sec.(type) { + default: + n = 0 + case string: + n, _ = strconv.ParseInt(value, 10, 64) + case int64: + n = value + } + return (time.Duration(n) * time.Second).String() +} + +func durationRound(duration interface{}) string { + var d time.Duration + switch duration := duration.(type) { + default: + d = 0 + case string: + d, _ = time.ParseDuration(duration) + case int64: + d = time.Duration(duration) + case time.Time: + d = time.Since(duration) + } + + u := uint64(d) + neg := d < 0 + if neg { + u = -u + } + + var ( + year = uint64(time.Hour) * 24 * 365 + month = uint64(time.Hour) * 24 * 30 + day = uint64(time.Hour) * 24 + hour = uint64(time.Hour) + minute = uint64(time.Minute) + second = uint64(time.Second) + ) + switch { + case u > year: + return strconv.FormatUint(u/year, 10) + "y" + case u > month: + return strconv.FormatUint(u/month, 10) + "mo" + case u > day: + return strconv.FormatUint(u/day, 10) + "d" + case u > hour: + return strconv.FormatUint(u/hour, 10) + "h" + case u > minute: + return strconv.FormatUint(u/minute, 10) + "m" + case u > second: + return strconv.FormatUint(u/second, 10) + "s" + } + return "0s" +} + +func toDate(fmt, str string) time.Time { + t, _ := time.ParseInLocation(fmt, str, time.Local) + return t +} + +func mustToDate(fmt, str string) (time.Time, error) { + return time.ParseInLocation(fmt, str, time.Local) +} + +func unixEpoch(date time.Time) string { + return strconv.FormatInt(date.Unix(), 10) +} diff --git a/vendor/github.com/go-task/slim-sprig/defaults.go b/vendor/github.com/go-task/slim-sprig/defaults.go new file mode 100644 index 00000000..b9f97966 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/defaults.go @@ -0,0 +1,163 @@ +package sprig + +import ( + "bytes" + "encoding/json" + "math/rand" + "reflect" + "strings" + "time" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +// dfault checks whether `given` is set, and returns default if not set. +// +// This returns `d` if `given` appears not to be set, and `given` otherwise. +// +// For numeric types 0 is unset. +// For strings, maps, arrays, and slices, len() = 0 is considered unset. +// For bool, false is unset. +// Structs are never considered unset. +// +// For everything else, including pointers, a nil value is unset. +func dfault(d interface{}, given ...interface{}) interface{} { + + if empty(given) || empty(given[0]) { + return d + } + return given[0] +} + +// empty returns true if the given value has the zero value for its type. +func empty(given interface{}) bool { + g := reflect.ValueOf(given) + if !g.IsValid() { + return true + } + + // Basically adapted from text/template.isTrue + switch g.Kind() { + default: + return g.IsNil() + case reflect.Array, reflect.Slice, reflect.Map, reflect.String: + return g.Len() == 0 + case reflect.Bool: + return !g.Bool() + case reflect.Complex64, reflect.Complex128: + return g.Complex() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return g.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return g.Uint() == 0 + case reflect.Float32, reflect.Float64: + return g.Float() == 0 + case reflect.Struct: + return false + } +} + +// coalesce returns the first non-empty value. +func coalesce(v ...interface{}) interface{} { + for _, val := range v { + if !empty(val) { + return val + } + } + return nil +} + +// all returns true if empty(x) is false for all values x in the list. +// If the list is empty, return true. +func all(v ...interface{}) bool { + for _, val := range v { + if empty(val) { + return false + } + } + return true +} + +// any returns true if empty(x) is false for any x in the list. +// If the list is empty, return false. +func any(v ...interface{}) bool { + for _, val := range v { + if !empty(val) { + return true + } + } + return false +} + +// fromJson decodes JSON into a structured value, ignoring errors. +func fromJson(v string) interface{} { + output, _ := mustFromJson(v) + return output +} + +// mustFromJson decodes JSON into a structured value, returning errors. +func mustFromJson(v string) (interface{}, error) { + var output interface{} + err := json.Unmarshal([]byte(v), &output) + return output, err +} + +// toJson encodes an item into a JSON string +func toJson(v interface{}) string { + output, _ := json.Marshal(v) + return string(output) +} + +func mustToJson(v interface{}) (string, error) { + output, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(output), nil +} + +// toPrettyJson encodes an item into a pretty (indented) JSON string +func toPrettyJson(v interface{}) string { + output, _ := json.MarshalIndent(v, "", " ") + return string(output) +} + +func mustToPrettyJson(v interface{}) (string, error) { + output, err := json.MarshalIndent(v, "", " ") + if err != nil { + return "", err + } + return string(output), nil +} + +// toRawJson encodes an item into a JSON string with no escaping of HTML characters. +func toRawJson(v interface{}) string { + output, err := mustToRawJson(v) + if err != nil { + panic(err) + } + return string(output) +} + +// mustToRawJson encodes an item into a JSON string with no escaping of HTML characters. +func mustToRawJson(v interface{}) (string, error) { + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + err := enc.Encode(&v) + if err != nil { + return "", err + } + return strings.TrimSuffix(buf.String(), "\n"), nil +} + +// ternary returns the first value if the last value is true, otherwise returns the second value. +func ternary(vt interface{}, vf interface{}, v bool) interface{} { + if v { + return vt + } + + return vf +} diff --git a/vendor/github.com/go-task/slim-sprig/dict.go b/vendor/github.com/go-task/slim-sprig/dict.go new file mode 100644 index 00000000..77ebc61b --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/dict.go @@ -0,0 +1,118 @@ +package sprig + +func get(d map[string]interface{}, key string) interface{} { + if val, ok := d[key]; ok { + return val + } + return "" +} + +func set(d map[string]interface{}, key string, value interface{}) map[string]interface{} { + d[key] = value + return d +} + +func unset(d map[string]interface{}, key string) map[string]interface{} { + delete(d, key) + return d +} + +func hasKey(d map[string]interface{}, key string) bool { + _, ok := d[key] + return ok +} + +func pluck(key string, d ...map[string]interface{}) []interface{} { + res := []interface{}{} + for _, dict := range d { + if val, ok := dict[key]; ok { + res = append(res, val) + } + } + return res +} + +func keys(dicts ...map[string]interface{}) []string { + k := []string{} + for _, dict := range dicts { + for key := range dict { + k = append(k, key) + } + } + return k +} + +func pick(dict map[string]interface{}, keys ...string) map[string]interface{} { + res := map[string]interface{}{} + for _, k := range keys { + if v, ok := dict[k]; ok { + res[k] = v + } + } + return res +} + +func omit(dict map[string]interface{}, keys ...string) map[string]interface{} { + res := map[string]interface{}{} + + omit := make(map[string]bool, len(keys)) + for _, k := range keys { + omit[k] = true + } + + for k, v := range dict { + if _, ok := omit[k]; !ok { + res[k] = v + } + } + return res +} + +func dict(v ...interface{}) map[string]interface{} { + dict := map[string]interface{}{} + lenv := len(v) + for i := 0; i < lenv; i += 2 { + key := strval(v[i]) + if i+1 >= lenv { + dict[key] = "" + continue + } + dict[key] = v[i+1] + } + return dict +} + +func values(dict map[string]interface{}) []interface{} { + values := []interface{}{} + for _, value := range dict { + values = append(values, value) + } + + return values +} + +func dig(ps ...interface{}) (interface{}, error) { + if len(ps) < 3 { + panic("dig needs at least three arguments") + } + dict := ps[len(ps)-1].(map[string]interface{}) + def := ps[len(ps)-2] + ks := make([]string, len(ps)-2) + for i := 0; i < len(ks); i++ { + ks[i] = ps[i].(string) + } + + return digFromDict(dict, def, ks) +} + +func digFromDict(dict map[string]interface{}, d interface{}, ks []string) (interface{}, error) { + k, ns := ks[0], ks[1:len(ks)] + step, has := dict[k] + if !has { + return d, nil + } + if len(ns) == 0 { + return step, nil + } + return digFromDict(step.(map[string]interface{}), d, ns) +} diff --git a/vendor/github.com/go-task/slim-sprig/doc.go b/vendor/github.com/go-task/slim-sprig/doc.go new file mode 100644 index 00000000..aabb9d44 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/doc.go @@ -0,0 +1,19 @@ +/* +Package sprig provides template functions for Go. + +This package contains a number of utility functions for working with data +inside of Go `html/template` and `text/template` files. + +To add these functions, use the `template.Funcs()` method: + + t := templates.New("foo").Funcs(sprig.FuncMap()) + +Note that you should add the function map before you parse any template files. + + In several cases, Sprig reverses the order of arguments from the way they + appear in the standard library. This is to make it easier to pipe + arguments into functions. + +See http://masterminds.github.io/sprig/ for more detailed documentation on each of the available functions. +*/ +package sprig diff --git a/vendor/github.com/go-task/slim-sprig/functions.go b/vendor/github.com/go-task/slim-sprig/functions.go new file mode 100644 index 00000000..5ea74f89 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/functions.go @@ -0,0 +1,317 @@ +package sprig + +import ( + "errors" + "html/template" + "math/rand" + "os" + "path" + "path/filepath" + "reflect" + "strconv" + "strings" + ttemplate "text/template" + "time" +) + +// FuncMap produces the function map. +// +// Use this to pass the functions into the template engine: +// +// tpl := template.New("foo").Funcs(sprig.FuncMap())) +// +func FuncMap() template.FuncMap { + return HtmlFuncMap() +} + +// HermeticTxtFuncMap returns a 'text/template'.FuncMap with only repeatable functions. +func HermeticTxtFuncMap() ttemplate.FuncMap { + r := TxtFuncMap() + for _, name := range nonhermeticFunctions { + delete(r, name) + } + return r +} + +// HermeticHtmlFuncMap returns an 'html/template'.Funcmap with only repeatable functions. +func HermeticHtmlFuncMap() template.FuncMap { + r := HtmlFuncMap() + for _, name := range nonhermeticFunctions { + delete(r, name) + } + return r +} + +// TxtFuncMap returns a 'text/template'.FuncMap +func TxtFuncMap() ttemplate.FuncMap { + return ttemplate.FuncMap(GenericFuncMap()) +} + +// HtmlFuncMap returns an 'html/template'.Funcmap +func HtmlFuncMap() template.FuncMap { + return template.FuncMap(GenericFuncMap()) +} + +// GenericFuncMap returns a copy of the basic function map as a map[string]interface{}. +func GenericFuncMap() map[string]interface{} { + gfm := make(map[string]interface{}, len(genericMap)) + for k, v := range genericMap { + gfm[k] = v + } + return gfm +} + +// These functions are not guaranteed to evaluate to the same result for given input, because they +// refer to the environment or global state. +var nonhermeticFunctions = []string{ + // Date functions + "date", + "date_in_zone", + "date_modify", + "now", + "htmlDate", + "htmlDateInZone", + "dateInZone", + "dateModify", + + // Strings + "randAlphaNum", + "randAlpha", + "randAscii", + "randNumeric", + "randBytes", + "uuidv4", + + // OS + "env", + "expandenv", + + // Network + "getHostByName", +} + +var genericMap = map[string]interface{}{ + "hello": func() string { return "Hello!" }, + + // Date functions + "ago": dateAgo, + "date": date, + "date_in_zone": dateInZone, + "date_modify": dateModify, + "dateInZone": dateInZone, + "dateModify": dateModify, + "duration": duration, + "durationRound": durationRound, + "htmlDate": htmlDate, + "htmlDateInZone": htmlDateInZone, + "must_date_modify": mustDateModify, + "mustDateModify": mustDateModify, + "mustToDate": mustToDate, + "now": time.Now, + "toDate": toDate, + "unixEpoch": unixEpoch, + + // Strings + "trunc": trunc, + "trim": strings.TrimSpace, + "upper": strings.ToUpper, + "lower": strings.ToLower, + "title": strings.Title, + "substr": substring, + // Switch order so that "foo" | repeat 5 + "repeat": func(count int, str string) string { return strings.Repeat(str, count) }, + // Deprecated: Use trimAll. + "trimall": func(a, b string) string { return strings.Trim(b, a) }, + // Switch order so that "$foo" | trimall "$" + "trimAll": func(a, b string) string { return strings.Trim(b, a) }, + "trimSuffix": func(a, b string) string { return strings.TrimSuffix(b, a) }, + "trimPrefix": func(a, b string) string { return strings.TrimPrefix(b, a) }, + // Switch order so that "foobar" | contains "foo" + "contains": func(substr string, str string) bool { return strings.Contains(str, substr) }, + "hasPrefix": func(substr string, str string) bool { return strings.HasPrefix(str, substr) }, + "hasSuffix": func(substr string, str string) bool { return strings.HasSuffix(str, substr) }, + "quote": quote, + "squote": squote, + "cat": cat, + "indent": indent, + "nindent": nindent, + "replace": replace, + "plural": plural, + "sha1sum": sha1sum, + "sha256sum": sha256sum, + "adler32sum": adler32sum, + "toString": strval, + + // Wrap Atoi to stop errors. + "atoi": func(a string) int { i, _ := strconv.Atoi(a); return i }, + "int64": toInt64, + "int": toInt, + "float64": toFloat64, + "seq": seq, + "toDecimal": toDecimal, + + //"gt": func(a, b int) bool {return a > b}, + //"gte": func(a, b int) bool {return a >= b}, + //"lt": func(a, b int) bool {return a < b}, + //"lte": func(a, b int) bool {return a <= b}, + + // split "/" foo/bar returns map[int]string{0: foo, 1: bar} + "split": split, + "splitList": func(sep, orig string) []string { return strings.Split(orig, sep) }, + // splitn "/" foo/bar/fuu returns map[int]string{0: foo, 1: bar/fuu} + "splitn": splitn, + "toStrings": strslice, + + "until": until, + "untilStep": untilStep, + + // VERY basic arithmetic. + "add1": func(i interface{}) int64 { return toInt64(i) + 1 }, + "add": func(i ...interface{}) int64 { + var a int64 = 0 + for _, b := range i { + a += toInt64(b) + } + return a + }, + "sub": func(a, b interface{}) int64 { return toInt64(a) - toInt64(b) }, + "div": func(a, b interface{}) int64 { return toInt64(a) / toInt64(b) }, + "mod": func(a, b interface{}) int64 { return toInt64(a) % toInt64(b) }, + "mul": func(a interface{}, v ...interface{}) int64 { + val := toInt64(a) + for _, b := range v { + val = val * toInt64(b) + } + return val + }, + "randInt": func(min, max int) int { return rand.Intn(max-min) + min }, + "biggest": max, + "max": max, + "min": min, + "maxf": maxf, + "minf": minf, + "ceil": ceil, + "floor": floor, + "round": round, + + // string slices. Note that we reverse the order b/c that's better + // for template processing. + "join": join, + "sortAlpha": sortAlpha, + + // Defaults + "default": dfault, + "empty": empty, + "coalesce": coalesce, + "all": all, + "any": any, + "compact": compact, + "mustCompact": mustCompact, + "fromJson": fromJson, + "toJson": toJson, + "toPrettyJson": toPrettyJson, + "toRawJson": toRawJson, + "mustFromJson": mustFromJson, + "mustToJson": mustToJson, + "mustToPrettyJson": mustToPrettyJson, + "mustToRawJson": mustToRawJson, + "ternary": ternary, + + // Reflection + "typeOf": typeOf, + "typeIs": typeIs, + "typeIsLike": typeIsLike, + "kindOf": kindOf, + "kindIs": kindIs, + "deepEqual": reflect.DeepEqual, + + // OS: + "env": os.Getenv, + "expandenv": os.ExpandEnv, + + // Network: + "getHostByName": getHostByName, + + // Paths: + "base": path.Base, + "dir": path.Dir, + "clean": path.Clean, + "ext": path.Ext, + "isAbs": path.IsAbs, + + // Filepaths: + "osBase": filepath.Base, + "osClean": filepath.Clean, + "osDir": filepath.Dir, + "osExt": filepath.Ext, + "osIsAbs": filepath.IsAbs, + + // Encoding: + "b64enc": base64encode, + "b64dec": base64decode, + "b32enc": base32encode, + "b32dec": base32decode, + + // Data Structures: + "tuple": list, // FIXME: with the addition of append/prepend these are no longer immutable. + "list": list, + "dict": dict, + "get": get, + "set": set, + "unset": unset, + "hasKey": hasKey, + "pluck": pluck, + "keys": keys, + "pick": pick, + "omit": omit, + "values": values, + + "append": push, "push": push, + "mustAppend": mustPush, "mustPush": mustPush, + "prepend": prepend, + "mustPrepend": mustPrepend, + "first": first, + "mustFirst": mustFirst, + "rest": rest, + "mustRest": mustRest, + "last": last, + "mustLast": mustLast, + "initial": initial, + "mustInitial": mustInitial, + "reverse": reverse, + "mustReverse": mustReverse, + "uniq": uniq, + "mustUniq": mustUniq, + "without": without, + "mustWithout": mustWithout, + "has": has, + "mustHas": mustHas, + "slice": slice, + "mustSlice": mustSlice, + "concat": concat, + "dig": dig, + "chunk": chunk, + "mustChunk": mustChunk, + + // Flow Control: + "fail": func(msg string) (string, error) { return "", errors.New(msg) }, + + // Regex + "regexMatch": regexMatch, + "mustRegexMatch": mustRegexMatch, + "regexFindAll": regexFindAll, + "mustRegexFindAll": mustRegexFindAll, + "regexFind": regexFind, + "mustRegexFind": mustRegexFind, + "regexReplaceAll": regexReplaceAll, + "mustRegexReplaceAll": mustRegexReplaceAll, + "regexReplaceAllLiteral": regexReplaceAllLiteral, + "mustRegexReplaceAllLiteral": mustRegexReplaceAllLiteral, + "regexSplit": regexSplit, + "mustRegexSplit": mustRegexSplit, + "regexQuoteMeta": regexQuoteMeta, + + // URLs: + "urlParse": urlParse, + "urlJoin": urlJoin, +} diff --git a/vendor/github.com/go-task/slim-sprig/list.go b/vendor/github.com/go-task/slim-sprig/list.go new file mode 100644 index 00000000..ca0fbb78 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/list.go @@ -0,0 +1,464 @@ +package sprig + +import ( + "fmt" + "math" + "reflect" + "sort" +) + +// Reflection is used in these functions so that slices and arrays of strings, +// ints, and other types not implementing []interface{} can be worked with. +// For example, this is useful if you need to work on the output of regexs. + +func list(v ...interface{}) []interface{} { + return v +} + +func push(list interface{}, v interface{}) []interface{} { + l, err := mustPush(list, v) + if err != nil { + panic(err) + } + + return l +} + +func mustPush(list interface{}, v interface{}) ([]interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + nl := make([]interface{}, l) + for i := 0; i < l; i++ { + nl[i] = l2.Index(i).Interface() + } + + return append(nl, v), nil + + default: + return nil, fmt.Errorf("Cannot push on type %s", tp) + } +} + +func prepend(list interface{}, v interface{}) []interface{} { + l, err := mustPrepend(list, v) + if err != nil { + panic(err) + } + + return l +} + +func mustPrepend(list interface{}, v interface{}) ([]interface{}, error) { + //return append([]interface{}{v}, list...) + + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + nl := make([]interface{}, l) + for i := 0; i < l; i++ { + nl[i] = l2.Index(i).Interface() + } + + return append([]interface{}{v}, nl...), nil + + default: + return nil, fmt.Errorf("Cannot prepend on type %s", tp) + } +} + +func chunk(size int, list interface{}) [][]interface{} { + l, err := mustChunk(size, list) + if err != nil { + panic(err) + } + + return l +} + +func mustChunk(size int, list interface{}) ([][]interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + + cs := int(math.Floor(float64(l-1)/float64(size)) + 1) + nl := make([][]interface{}, cs) + + for i := 0; i < cs; i++ { + clen := size + if i == cs-1 { + clen = int(math.Floor(math.Mod(float64(l), float64(size)))) + if clen == 0 { + clen = size + } + } + + nl[i] = make([]interface{}, clen) + + for j := 0; j < clen; j++ { + ix := i*size + j + nl[i][j] = l2.Index(ix).Interface() + } + } + + return nl, nil + + default: + return nil, fmt.Errorf("Cannot chunk type %s", tp) + } +} + +func last(list interface{}) interface{} { + l, err := mustLast(list) + if err != nil { + panic(err) + } + + return l +} + +func mustLast(list interface{}) (interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + if l == 0 { + return nil, nil + } + + return l2.Index(l - 1).Interface(), nil + default: + return nil, fmt.Errorf("Cannot find last on type %s", tp) + } +} + +func first(list interface{}) interface{} { + l, err := mustFirst(list) + if err != nil { + panic(err) + } + + return l +} + +func mustFirst(list interface{}) (interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + if l == 0 { + return nil, nil + } + + return l2.Index(0).Interface(), nil + default: + return nil, fmt.Errorf("Cannot find first on type %s", tp) + } +} + +func rest(list interface{}) []interface{} { + l, err := mustRest(list) + if err != nil { + panic(err) + } + + return l +} + +func mustRest(list interface{}) ([]interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + if l == 0 { + return nil, nil + } + + nl := make([]interface{}, l-1) + for i := 1; i < l; i++ { + nl[i-1] = l2.Index(i).Interface() + } + + return nl, nil + default: + return nil, fmt.Errorf("Cannot find rest on type %s", tp) + } +} + +func initial(list interface{}) []interface{} { + l, err := mustInitial(list) + if err != nil { + panic(err) + } + + return l +} + +func mustInitial(list interface{}) ([]interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + if l == 0 { + return nil, nil + } + + nl := make([]interface{}, l-1) + for i := 0; i < l-1; i++ { + nl[i] = l2.Index(i).Interface() + } + + return nl, nil + default: + return nil, fmt.Errorf("Cannot find initial on type %s", tp) + } +} + +func sortAlpha(list interface{}) []string { + k := reflect.Indirect(reflect.ValueOf(list)).Kind() + switch k { + case reflect.Slice, reflect.Array: + a := strslice(list) + s := sort.StringSlice(a) + s.Sort() + return s + } + return []string{strval(list)} +} + +func reverse(v interface{}) []interface{} { + l, err := mustReverse(v) + if err != nil { + panic(err) + } + + return l +} + +func mustReverse(v interface{}) ([]interface{}, error) { + tp := reflect.TypeOf(v).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(v) + + l := l2.Len() + // We do not sort in place because the incoming array should not be altered. + nl := make([]interface{}, l) + for i := 0; i < l; i++ { + nl[l-i-1] = l2.Index(i).Interface() + } + + return nl, nil + default: + return nil, fmt.Errorf("Cannot find reverse on type %s", tp) + } +} + +func compact(list interface{}) []interface{} { + l, err := mustCompact(list) + if err != nil { + panic(err) + } + + return l +} + +func mustCompact(list interface{}) ([]interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + nl := []interface{}{} + var item interface{} + for i := 0; i < l; i++ { + item = l2.Index(i).Interface() + if !empty(item) { + nl = append(nl, item) + } + } + + return nl, nil + default: + return nil, fmt.Errorf("Cannot compact on type %s", tp) + } +} + +func uniq(list interface{}) []interface{} { + l, err := mustUniq(list) + if err != nil { + panic(err) + } + + return l +} + +func mustUniq(list interface{}) ([]interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + dest := []interface{}{} + var item interface{} + for i := 0; i < l; i++ { + item = l2.Index(i).Interface() + if !inList(dest, item) { + dest = append(dest, item) + } + } + + return dest, nil + default: + return nil, fmt.Errorf("Cannot find uniq on type %s", tp) + } +} + +func inList(haystack []interface{}, needle interface{}) bool { + for _, h := range haystack { + if reflect.DeepEqual(needle, h) { + return true + } + } + return false +} + +func without(list interface{}, omit ...interface{}) []interface{} { + l, err := mustWithout(list, omit...) + if err != nil { + panic(err) + } + + return l +} + +func mustWithout(list interface{}, omit ...interface{}) ([]interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + res := []interface{}{} + var item interface{} + for i := 0; i < l; i++ { + item = l2.Index(i).Interface() + if !inList(omit, item) { + res = append(res, item) + } + } + + return res, nil + default: + return nil, fmt.Errorf("Cannot find without on type %s", tp) + } +} + +func has(needle interface{}, haystack interface{}) bool { + l, err := mustHas(needle, haystack) + if err != nil { + panic(err) + } + + return l +} + +func mustHas(needle interface{}, haystack interface{}) (bool, error) { + if haystack == nil { + return false, nil + } + tp := reflect.TypeOf(haystack).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(haystack) + var item interface{} + l := l2.Len() + for i := 0; i < l; i++ { + item = l2.Index(i).Interface() + if reflect.DeepEqual(needle, item) { + return true, nil + } + } + + return false, nil + default: + return false, fmt.Errorf("Cannot find has on type %s", tp) + } +} + +// $list := [1, 2, 3, 4, 5] +// slice $list -> list[0:5] = list[:] +// slice $list 0 3 -> list[0:3] = list[:3] +// slice $list 3 5 -> list[3:5] +// slice $list 3 -> list[3:5] = list[3:] +func slice(list interface{}, indices ...interface{}) interface{} { + l, err := mustSlice(list, indices...) + if err != nil { + panic(err) + } + + return l +} + +func mustSlice(list interface{}, indices ...interface{}) (interface{}, error) { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + + l := l2.Len() + if l == 0 { + return nil, nil + } + + var start, end int + if len(indices) > 0 { + start = toInt(indices[0]) + } + if len(indices) < 2 { + end = l + } else { + end = toInt(indices[1]) + } + + return l2.Slice(start, end).Interface(), nil + default: + return nil, fmt.Errorf("list should be type of slice or array but %s", tp) + } +} + +func concat(lists ...interface{}) interface{} { + var res []interface{} + for _, list := range lists { + tp := reflect.TypeOf(list).Kind() + switch tp { + case reflect.Slice, reflect.Array: + l2 := reflect.ValueOf(list) + for i := 0; i < l2.Len(); i++ { + res = append(res, l2.Index(i).Interface()) + } + default: + panic(fmt.Sprintf("Cannot concat type %s as list", tp)) + } + } + return res +} diff --git a/vendor/github.com/go-task/slim-sprig/network.go b/vendor/github.com/go-task/slim-sprig/network.go new file mode 100644 index 00000000..108d78a9 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/network.go @@ -0,0 +1,12 @@ +package sprig + +import ( + "math/rand" + "net" +) + +func getHostByName(name string) string { + addrs, _ := net.LookupHost(name) + //TODO: add error handing when release v3 comes out + return addrs[rand.Intn(len(addrs))] +} diff --git a/vendor/github.com/go-task/slim-sprig/numeric.go b/vendor/github.com/go-task/slim-sprig/numeric.go new file mode 100644 index 00000000..98cbb37a --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/numeric.go @@ -0,0 +1,228 @@ +package sprig + +import ( + "fmt" + "math" + "reflect" + "strconv" + "strings" +) + +// toFloat64 converts 64-bit floats +func toFloat64(v interface{}) float64 { + if str, ok := v.(string); ok { + iv, err := strconv.ParseFloat(str, 64) + if err != nil { + return 0 + } + return iv + } + + val := reflect.Indirect(reflect.ValueOf(v)) + switch val.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return float64(val.Int()) + case reflect.Uint8, reflect.Uint16, reflect.Uint32: + return float64(val.Uint()) + case reflect.Uint, reflect.Uint64: + return float64(val.Uint()) + case reflect.Float32, reflect.Float64: + return val.Float() + case reflect.Bool: + if val.Bool() { + return 1 + } + return 0 + default: + return 0 + } +} + +func toInt(v interface{}) int { + //It's not optimal. Bud I don't want duplicate toInt64 code. + return int(toInt64(v)) +} + +// toInt64 converts integer types to 64-bit integers +func toInt64(v interface{}) int64 { + if str, ok := v.(string); ok { + iv, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return 0 + } + return iv + } + + val := reflect.Indirect(reflect.ValueOf(v)) + switch val.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return val.Int() + case reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(val.Uint()) + case reflect.Uint, reflect.Uint64: + tv := val.Uint() + if tv <= math.MaxInt64 { + return int64(tv) + } + // TODO: What is the sensible thing to do here? + return math.MaxInt64 + case reflect.Float32, reflect.Float64: + return int64(val.Float()) + case reflect.Bool: + if val.Bool() { + return 1 + } + return 0 + default: + return 0 + } +} + +func max(a interface{}, i ...interface{}) int64 { + aa := toInt64(a) + for _, b := range i { + bb := toInt64(b) + if bb > aa { + aa = bb + } + } + return aa +} + +func maxf(a interface{}, i ...interface{}) float64 { + aa := toFloat64(a) + for _, b := range i { + bb := toFloat64(b) + aa = math.Max(aa, bb) + } + return aa +} + +func min(a interface{}, i ...interface{}) int64 { + aa := toInt64(a) + for _, b := range i { + bb := toInt64(b) + if bb < aa { + aa = bb + } + } + return aa +} + +func minf(a interface{}, i ...interface{}) float64 { + aa := toFloat64(a) + for _, b := range i { + bb := toFloat64(b) + aa = math.Min(aa, bb) + } + return aa +} + +func until(count int) []int { + step := 1 + if count < 0 { + step = -1 + } + return untilStep(0, count, step) +} + +func untilStep(start, stop, step int) []int { + v := []int{} + + if stop < start { + if step >= 0 { + return v + } + for i := start; i > stop; i += step { + v = append(v, i) + } + return v + } + + if step <= 0 { + return v + } + for i := start; i < stop; i += step { + v = append(v, i) + } + return v +} + +func floor(a interface{}) float64 { + aa := toFloat64(a) + return math.Floor(aa) +} + +func ceil(a interface{}) float64 { + aa := toFloat64(a) + return math.Ceil(aa) +} + +func round(a interface{}, p int, rOpt ...float64) float64 { + roundOn := .5 + if len(rOpt) > 0 { + roundOn = rOpt[0] + } + val := toFloat64(a) + places := toFloat64(p) + + var round float64 + pow := math.Pow(10, places) + digit := pow * val + _, div := math.Modf(digit) + if div >= roundOn { + round = math.Ceil(digit) + } else { + round = math.Floor(digit) + } + return round / pow +} + +// converts unix octal to decimal +func toDecimal(v interface{}) int64 { + result, err := strconv.ParseInt(fmt.Sprint(v), 8, 64) + if err != nil { + return 0 + } + return result +} + +func seq(params ...int) string { + increment := 1 + switch len(params) { + case 0: + return "" + case 1: + start := 1 + end := params[0] + if end < start { + increment = -1 + } + return intArrayToString(untilStep(start, end+increment, increment), " ") + case 3: + start := params[0] + end := params[2] + step := params[1] + if end < start { + increment = -1 + if step > 0 { + return "" + } + } + return intArrayToString(untilStep(start, end+increment, step), " ") + case 2: + start := params[0] + end := params[1] + step := 1 + if end < start { + step = -1 + } + return intArrayToString(untilStep(start, end+step, step), " ") + default: + return "" + } +} + +func intArrayToString(slice []int, delimeter string) string { + return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(slice)), delimeter), "[]") +} diff --git a/vendor/github.com/go-task/slim-sprig/reflect.go b/vendor/github.com/go-task/slim-sprig/reflect.go new file mode 100644 index 00000000..8a65c132 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/reflect.go @@ -0,0 +1,28 @@ +package sprig + +import ( + "fmt" + "reflect" +) + +// typeIs returns true if the src is the type named in target. +func typeIs(target string, src interface{}) bool { + return target == typeOf(src) +} + +func typeIsLike(target string, src interface{}) bool { + t := typeOf(src) + return target == t || "*"+target == t +} + +func typeOf(src interface{}) string { + return fmt.Sprintf("%T", src) +} + +func kindIs(target string, src interface{}) bool { + return target == kindOf(src) +} + +func kindOf(src interface{}) string { + return reflect.ValueOf(src).Kind().String() +} diff --git a/vendor/github.com/go-task/slim-sprig/regex.go b/vendor/github.com/go-task/slim-sprig/regex.go new file mode 100644 index 00000000..fab55101 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/regex.go @@ -0,0 +1,83 @@ +package sprig + +import ( + "regexp" +) + +func regexMatch(regex string, s string) bool { + match, _ := regexp.MatchString(regex, s) + return match +} + +func mustRegexMatch(regex string, s string) (bool, error) { + return regexp.MatchString(regex, s) +} + +func regexFindAll(regex string, s string, n int) []string { + r := regexp.MustCompile(regex) + return r.FindAllString(s, n) +} + +func mustRegexFindAll(regex string, s string, n int) ([]string, error) { + r, err := regexp.Compile(regex) + if err != nil { + return []string{}, err + } + return r.FindAllString(s, n), nil +} + +func regexFind(regex string, s string) string { + r := regexp.MustCompile(regex) + return r.FindString(s) +} + +func mustRegexFind(regex string, s string) (string, error) { + r, err := regexp.Compile(regex) + if err != nil { + return "", err + } + return r.FindString(s), nil +} + +func regexReplaceAll(regex string, s string, repl string) string { + r := regexp.MustCompile(regex) + return r.ReplaceAllString(s, repl) +} + +func mustRegexReplaceAll(regex string, s string, repl string) (string, error) { + r, err := regexp.Compile(regex) + if err != nil { + return "", err + } + return r.ReplaceAllString(s, repl), nil +} + +func regexReplaceAllLiteral(regex string, s string, repl string) string { + r := regexp.MustCompile(regex) + return r.ReplaceAllLiteralString(s, repl) +} + +func mustRegexReplaceAllLiteral(regex string, s string, repl string) (string, error) { + r, err := regexp.Compile(regex) + if err != nil { + return "", err + } + return r.ReplaceAllLiteralString(s, repl), nil +} + +func regexSplit(regex string, s string, n int) []string { + r := regexp.MustCompile(regex) + return r.Split(s, n) +} + +func mustRegexSplit(regex string, s string, n int) ([]string, error) { + r, err := regexp.Compile(regex) + if err != nil { + return []string{}, err + } + return r.Split(s, n), nil +} + +func regexQuoteMeta(s string) string { + return regexp.QuoteMeta(s) +} diff --git a/vendor/github.com/go-task/slim-sprig/strings.go b/vendor/github.com/go-task/slim-sprig/strings.go new file mode 100644 index 00000000..3c62d6b6 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/strings.go @@ -0,0 +1,189 @@ +package sprig + +import ( + "encoding/base32" + "encoding/base64" + "fmt" + "reflect" + "strconv" + "strings" +) + +func base64encode(v string) string { + return base64.StdEncoding.EncodeToString([]byte(v)) +} + +func base64decode(v string) string { + data, err := base64.StdEncoding.DecodeString(v) + if err != nil { + return err.Error() + } + return string(data) +} + +func base32encode(v string) string { + return base32.StdEncoding.EncodeToString([]byte(v)) +} + +func base32decode(v string) string { + data, err := base32.StdEncoding.DecodeString(v) + if err != nil { + return err.Error() + } + return string(data) +} + +func quote(str ...interface{}) string { + out := make([]string, 0, len(str)) + for _, s := range str { + if s != nil { + out = append(out, fmt.Sprintf("%q", strval(s))) + } + } + return strings.Join(out, " ") +} + +func squote(str ...interface{}) string { + out := make([]string, 0, len(str)) + for _, s := range str { + if s != nil { + out = append(out, fmt.Sprintf("'%v'", s)) + } + } + return strings.Join(out, " ") +} + +func cat(v ...interface{}) string { + v = removeNilElements(v) + r := strings.TrimSpace(strings.Repeat("%v ", len(v))) + return fmt.Sprintf(r, v...) +} + +func indent(spaces int, v string) string { + pad := strings.Repeat(" ", spaces) + return pad + strings.Replace(v, "\n", "\n"+pad, -1) +} + +func nindent(spaces int, v string) string { + return "\n" + indent(spaces, v) +} + +func replace(old, new, src string) string { + return strings.Replace(src, old, new, -1) +} + +func plural(one, many string, count int) string { + if count == 1 { + return one + } + return many +} + +func strslice(v interface{}) []string { + switch v := v.(type) { + case []string: + return v + case []interface{}: + b := make([]string, 0, len(v)) + for _, s := range v { + if s != nil { + b = append(b, strval(s)) + } + } + return b + default: + val := reflect.ValueOf(v) + switch val.Kind() { + case reflect.Array, reflect.Slice: + l := val.Len() + b := make([]string, 0, l) + for i := 0; i < l; i++ { + value := val.Index(i).Interface() + if value != nil { + b = append(b, strval(value)) + } + } + return b + default: + if v == nil { + return []string{} + } + + return []string{strval(v)} + } + } +} + +func removeNilElements(v []interface{}) []interface{} { + newSlice := make([]interface{}, 0, len(v)) + for _, i := range v { + if i != nil { + newSlice = append(newSlice, i) + } + } + return newSlice +} + +func strval(v interface{}) string { + switch v := v.(type) { + case string: + return v + case []byte: + return string(v) + case error: + return v.Error() + case fmt.Stringer: + return v.String() + default: + return fmt.Sprintf("%v", v) + } +} + +func trunc(c int, s string) string { + if c < 0 && len(s)+c > 0 { + return s[len(s)+c:] + } + if c >= 0 && len(s) > c { + return s[:c] + } + return s +} + +func join(sep string, v interface{}) string { + return strings.Join(strslice(v), sep) +} + +func split(sep, orig string) map[string]string { + parts := strings.Split(orig, sep) + res := make(map[string]string, len(parts)) + for i, v := range parts { + res["_"+strconv.Itoa(i)] = v + } + return res +} + +func splitn(sep string, n int, orig string) map[string]string { + parts := strings.SplitN(orig, sep, n) + res := make(map[string]string, len(parts)) + for i, v := range parts { + res["_"+strconv.Itoa(i)] = v + } + return res +} + +// substring creates a substring of the given string. +// +// If start is < 0, this calls string[:end]. +// +// If start is >= 0 and end < 0 or end bigger than s length, this calls string[start:] +// +// Otherwise, this calls string[start, end]. +func substring(start, end int, s string) string { + if start < 0 { + return s[:end] + } + if end < 0 || end > len(s) { + return s[start:] + } + return s[start:end] +} diff --git a/vendor/github.com/go-task/slim-sprig/url.go b/vendor/github.com/go-task/slim-sprig/url.go new file mode 100644 index 00000000..b8e120e1 --- /dev/null +++ b/vendor/github.com/go-task/slim-sprig/url.go @@ -0,0 +1,66 @@ +package sprig + +import ( + "fmt" + "net/url" + "reflect" +) + +func dictGetOrEmpty(dict map[string]interface{}, key string) string { + value, ok := dict[key] + if !ok { + return "" + } + tp := reflect.TypeOf(value).Kind() + if tp != reflect.String { + panic(fmt.Sprintf("unable to parse %s key, must be of type string, but %s found", key, tp.String())) + } + return reflect.ValueOf(value).String() +} + +// parses given URL to return dict object +func urlParse(v string) map[string]interface{} { + dict := map[string]interface{}{} + parsedURL, err := url.Parse(v) + if err != nil { + panic(fmt.Sprintf("unable to parse url: %s", err)) + } + dict["scheme"] = parsedURL.Scheme + dict["host"] = parsedURL.Host + dict["hostname"] = parsedURL.Hostname() + dict["path"] = parsedURL.Path + dict["query"] = parsedURL.RawQuery + dict["opaque"] = parsedURL.Opaque + dict["fragment"] = parsedURL.Fragment + if parsedURL.User != nil { + dict["userinfo"] = parsedURL.User.String() + } else { + dict["userinfo"] = "" + } + + return dict +} + +// join given dict to URL string +func urlJoin(d map[string]interface{}) string { + resURL := url.URL{ + Scheme: dictGetOrEmpty(d, "scheme"), + Host: dictGetOrEmpty(d, "host"), + Path: dictGetOrEmpty(d, "path"), + RawQuery: dictGetOrEmpty(d, "query"), + Opaque: dictGetOrEmpty(d, "opaque"), + Fragment: dictGetOrEmpty(d, "fragment"), + } + userinfo := dictGetOrEmpty(d, "userinfo") + var user *url.Userinfo + if userinfo != "" { + tempURL, err := url.Parse(fmt.Sprintf("proto://%s@host", userinfo)) + if err != nil { + panic(fmt.Sprintf("unable to parse userinfo in dict: %s", err)) + } + user = tempURL.User + } + + resURL.User = user + return resURL.String() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/.gitignore b/vendor/github.com/lucas-clemente/quic-go/.gitignore new file mode 100644 index 00000000..3cc06f24 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/.gitignore @@ -0,0 +1,17 @@ +debug +debug.test +main +mockgen_tmp.go +*.qtr +*.qlog +*.txt +race.[0-9]* + +fuzzing/*/*.zip +fuzzing/*/coverprofile +fuzzing/*/crashers +fuzzing/*/sonarprofile +fuzzing/*/suppressions +fuzzing/*/corpus/ + +gomock_reflect_*/ diff --git a/vendor/github.com/lucas-clemente/quic-go/.golangci.yml b/vendor/github.com/lucas-clemente/quic-go/.golangci.yml new file mode 100644 index 00000000..2589c053 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/.golangci.yml @@ -0,0 +1,44 @@ +run: + skip-files: + - internal/qtls/structs_equal_test.go + +linters-settings: + depguard: + type: blacklist + packages: + - github.com/marten-seemann/qtls + packages-with-error-message: + - github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls" + misspell: + ignore-words: + - ect + +linters: + disable-all: true + enable: + - asciicheck + - deadcode + - depguard + - exhaustive + - exportloopref + - goimports + - gofmt # redundant, since gofmt *should* be a no-op after gofumpt + - gofumpt + - gosimple + - ineffassign + - misspell + - prealloc + - staticcheck + - stylecheck + - structcheck + - unconvert + - unparam + - unused + - varcheck + - vet + +issues: + exclude-rules: + - path: internal/qtls + linters: + - depguard diff --git a/vendor/github.com/lucas-clemente/quic-go/Changelog.md b/vendor/github.com/lucas-clemente/quic-go/Changelog.md new file mode 100644 index 00000000..c1c33232 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/Changelog.md @@ -0,0 +1,109 @@ +# Changelog + +## v0.22.0 (2021-07-25) + +- Use `ReadBatch` to read multiple UDP packets from the socket with a single syscall +- Add a config option (`Config.DisableVersionNegotiationPackets`) to disable sending of Version Negotiation packets +- Drop support for QUIC draft versions 32 and 34 +- Remove the `RetireBugBackwardsCompatibilityMode`, which was intended to mitigate a bug when retiring connection IDs in quic-go in v0.17.2 and ealier + +## v0.21.2 (2021-07-15) + +- Update qtls (for Go 1.15, 1.16 and 1.17rc1) to include the fix for the crypto/tls panic (see https://groups.google.com/g/golang-dev/c/5LJ2V7rd-Ag/m/YGLHVBZ6AAAJ for details) + +## v0.21.0 (2021-06-01) + +- quic-go now supports RFC 9000! + +## v0.20.0 (2021-03-19) + +- Remove the `quic.Config.HandshakeTimeout`. Introduce a `quic.Config.HandshakeIdleTimeout`. + +## v0.17.1 (2020-06-20) + +- Supports QUIC WG draft-29. +- Improve bundling of ACK frames (#2543). + +## v0.16.0 (2020-05-31) + +- Supports QUIC WG draft-28. + +## v0.15.0 (2020-03-01) + +- Supports QUIC WG draft-27. +- Add support for 0-RTT. +- Remove `Session.Close()`. Applications need to pass an application error code to the transport using `Session.CloseWithError()`. +- Make the TLS Cipher Suites configurable (via `tls.Config.CipherSuites`). + +## v0.14.0 (2019-12-04) + +- Supports QUIC WG draft-24. + +## v0.13.0 (2019-11-05) + +- Supports QUIC WG draft-23. +- Add an `EarlyListener` that allows sending of 0.5-RTT data. +- Add a `TokenStore` to store address validation tokens. +- Issue and use new connection IDs during a connection. + +## v0.12.0 (2019-08-05) + +- Implement HTTP/3. +- Rename `quic.Cookie` to `quic.Token` and `quic.Config.AcceptCookie` to `quic.Config.AcceptToken`. +- Distinguish between Retry tokens and tokens sent in NEW_TOKEN frames. +- Enforce application protocol negotiation (via `tls.Config.NextProtos`). +- Use a varint for error codes. +- Add support for [quic-trace](https://github.com/google/quic-trace). +- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`. +- Implement TLS key updates. + +## v0.11.0 (2019-04-05) + +- Drop support for gQUIC. For qQUIC support, please switch to the *gquic* branch. +- Implement QUIC WG draft-19. +- Use [qtls](https://github.com/marten-seemann/qtls) for TLS 1.3. +- Return a `tls.ConnectionState` from `quic.Session.ConnectionState()`. +- Remove the error return values from `quic.Stream.CancelRead()` and `quic.Stream.CancelWrite()` + +## v0.10.0 (2018-08-28) + +- Add support for QUIC 44, drop support for QUIC 42. + +## v0.9.0 (2018-08-15) + +- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC). +- Split Session.Close into one method for regular closing and one for closing with an error. + +## v0.8.0 (2018-06-26) + +- Add support for unidirectional streams (for IETF QUIC). +- Add a `quic.Config` option for the maximum number of incoming streams. +- Add support for QUIC 42 and 43. +- Add dial functions that use a context. +- Multiplex clients on a net.PacketConn, when using Dial(conn). + +## v0.7.0 (2018-02-03) + +- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored. +- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`. +- Expose the `ConnectionState` in the `Session` (experimental API). +- Implement packet pacing. + +## v0.6.0 (2017-12-12) + +- Add support for QUIC 39, drop support for QUIC 35 - 37 +- Added `quic.Config` options for maximal flow control windows +- Add a `quic.Config` option for QUIC versions +- Add a `quic.Config` option to request omission of the connection ID from a server +- Add a `quic.Config` option to configure the source address validation +- Add a `quic.Config` option to configure the handshake timeout +- Add a `quic.Config` option to configure the idle timeout +- Add a `quic.Config` option to configure keep-alive +- Rename the STK to Cookie +- Implement `net.Conn`-style deadlines for streams +- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details. +- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details. +- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper` +- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn` +- Drop support for Go 1.7 and 1.8. +- Various bugfixes diff --git a/vendor/github.com/lucas-clemente/quic-go/LICENSE b/vendor/github.com/lucas-clemente/quic-go/LICENSE new file mode 100644 index 00000000..51378bef --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016 the quic-go authors & Google, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/lucas-clemente/quic-go/README.md b/vendor/github.com/lucas-clemente/quic-go/README.md new file mode 100644 index 00000000..f284174c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/README.md @@ -0,0 +1,61 @@ +# A QUIC implementation in pure Go + + + +[![PkgGoDev](https://pkg.go.dev/badge/github.com/lucas-clemente/quic-go)](https://pkg.go.dev/github.com/lucas-clemente/quic-go) +[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/) + +quic-go is an implementation of the [QUIC protocol, RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000) protocol in Go, including the [Unreliable Datagram Extension, RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221). + +In addition to RFC 9000, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem. + +## Guides + +*We currently support Go 1.16.x, Go 1.17.x, and Go 1.18.x.* + +Running tests: + + go test ./... + +### QUIC without HTTP/3 + +Take a look at [this echo example](example/echo/echo.go). + +## Usage + +### As a server + +See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go: + +```go +http.Handle("/", http.FileServer(http.Dir(wwwDir))) +http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil) +``` + +### As a client + +See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`. + +```go +http.Client{ + Transport: &http3.RoundTripper{}, +} +``` + +## Projects using quic-go + +| Project | Description | Stars | +|------------------------------------------------------|--------------------------------------------------------------------------------------------------------|-------| +| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | +| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | +| [go-ipfs](https://github.com/ipfs/go-ipfs) | IPFS implementation in go | ![GitHub Repo stars](https://img.shields.io/github/stars/ipfs/go-ipfs?style=flat-square) | +| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | +| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | +| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | +| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | +| [OONI Probe](https://github.com/ooni/probe-cli) | The Open Observatory of Network Interference (OONI) aims to empower decentralized efforts in documenting Internet censorship around the world. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | + + +## Contributing + +We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. diff --git a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go new file mode 100644 index 00000000..c0b7067d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go @@ -0,0 +1,80 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type packetBuffer struct { + Data []byte + + // refCount counts how many packets Data is used in. + // It doesn't support concurrent use. + // It is > 1 when used for coalesced packet. + refCount int +} + +// Split increases the refCount. +// It must be called when a packet buffer is used for more than one packet, +// e.g. when splitting coalesced packets. +func (b *packetBuffer) Split() { + b.refCount++ +} + +// Decrement decrements the reference counter. +// It doesn't put the buffer back into the pool. +func (b *packetBuffer) Decrement() { + b.refCount-- + if b.refCount < 0 { + panic("negative packetBuffer refCount") + } +} + +// MaybeRelease puts the packet buffer back into the pool, +// if the reference counter already reached 0. +func (b *packetBuffer) MaybeRelease() { + // only put the packetBuffer back if it's not used any more + if b.refCount == 0 { + b.putBack() + } +} + +// Release puts back the packet buffer into the pool. +// It should be called when processing is definitely finished. +func (b *packetBuffer) Release() { + b.Decrement() + if b.refCount != 0 { + panic("packetBuffer refCount not zero") + } + b.putBack() +} + +// Len returns the length of Data +func (b *packetBuffer) Len() protocol.ByteCount { + return protocol.ByteCount(len(b.Data)) +} + +func (b *packetBuffer) putBack() { + if cap(b.Data) != int(protocol.MaxPacketBufferSize) { + panic("putPacketBuffer called with packet of wrong size!") + } + bufferPool.Put(b) +} + +var bufferPool sync.Pool + +func getPacketBuffer() *packetBuffer { + buf := bufferPool.Get().(*packetBuffer) + buf.refCount = 1 + buf.Data = buf.Data[:0] + return buf +} + +func init() { + bufferPool.New = func() interface{} { + return &packetBuffer{ + Data: make([]byte, 0, protocol.MaxPacketBufferSize), + } + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go new file mode 100644 index 00000000..be8390e6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/client.go @@ -0,0 +1,339 @@ +package quic + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "strings" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" +) + +type client struct { + sconn sendConn + // If the client is created with DialAddr, we create a packet conn. + // If it is started with Dial, we take a packet conn as a parameter. + createdPacketConn bool + + use0RTT bool + + packetHandlers packetHandlerManager + + tlsConf *tls.Config + config *Config + + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID + + initialPacketNumber protocol.PacketNumber + hasNegotiatedVersion bool + version protocol.VersionNumber + + handshakeChan chan struct{} + + conn quicConn + + tracer logging.ConnectionTracer + tracingID uint64 + logger utils.Logger +} + +var ( + // make it possible to mock connection ID generation in the tests + generateConnectionID = protocol.GenerateConnectionID + generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial +) + +// DialAddr establishes a new QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +// The hostname for SNI is taken from the given address. +// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. +func DialAddr( + addr string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return DialAddrContext(context.Background(), addr, tlsConf, config) +} + +// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +// The hostname for SNI is taken from the given address. +// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. +func DialAddrEarly( + addr string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + return DialAddrEarlyContext(context.Background(), addr, tlsConf, config) +} + +// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context. +// See DialAddrEarly for details +func DialAddrEarlyContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) + if err != nil { + return nil, err + } + utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") + return conn, nil +} + +// DialAddrContext establishes a new QUIC connection to a server using the provided context. +// See DialAddr for details. +func DialAddrContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return dialAddrContext(ctx, addr, tlsConf, config, false) +} + +func dialAddrContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, + use0RTT bool, +) (quicConn, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) +} + +// Dial establishes a new QUIC connection to a server using a net.PacketConn. If +// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn +// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP +// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write +// packets. The same PacketConn can be used for multiple calls to Dial and +// Listen, QUIC connection IDs are used for demultiplexing the different +// connections. The host parameter is used for SNI. The tls.Config must define +// an application protocol (using NextProtos). +func Dial( + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false) +} + +// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. +// The same PacketConn can be used for multiple calls to Dial and Listen, +// QUIC connection IDs are used for demultiplexing the different connections. +// The host parameter is used for SNI. +// The tls.Config must define an application protocol (using NextProtos). +func DialEarly( + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) +} + +// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. +// See DialEarly for details. +func DialEarlyContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (EarlyConnection, error) { + return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false) +} + +// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. +// See Dial for details. +func DialContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (Connection, error) { + return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false) +} + +func dialContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, + use0RTT bool, + createdPacketConn bool, +) (quicConn, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(config); err != nil { + return nil, err + } + config = populateClientConfig(config, createdPacketConn) + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + if err != nil { + return nil, err + } + c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn) + if err != nil { + return nil, err + } + c.packetHandlers = packetHandlers + + c.tracingID = nextConnTracingID() + if c.config.Tracer != nil { + c.tracer = c.config.Tracer.TracerForConnection( + context.WithValue(ctx, ConnectionTracingKey, c.tracingID), + protocol.PerspectiveClient, + c.destConnID, + ) + } + if c.tracer != nil { + c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) + } + if err := c.dial(ctx); err != nil { + return nil, err + } + return c.conn, nil +} + +func newClient( + pconn net.PacketConn, + remoteAddr net.Addr, + config *Config, + tlsConf *tls.Config, + host string, + use0RTT bool, + createdPacketConn bool, +) (*client, error) { + if tlsConf == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() + } + if tlsConf.ServerName == "" { + sni := host + if strings.IndexByte(sni, ':') != -1 { + var err error + sni, _, err = net.SplitHostPort(sni) + if err != nil { + return nil, err + } + } + + tlsConf.ServerName = sni + } + + // check that all versions are actually supported + if config != nil { + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + } + } + + srcConnID, err := generateConnectionID(config.ConnectionIDLength) + if err != nil { + return nil, err + } + destConnID, err := generateConnectionIDForInitial() + if err != nil { + return nil, err + } + c := &client{ + srcConnID: srcConnID, + destConnID: destConnID, + sconn: newSendPconn(pconn, remoteAddr), + createdPacketConn: createdPacketConn, + use0RTT: use0RTT, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), + } + return c, nil +} + +func (c *client) dial(ctx context.Context) error { + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + + c.conn = newClientConnection( + c.sconn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + ) + c.packetHandlers.Add(c.srcConnID, c.conn) + + errorChan := make(chan error, 1) + go func() { + err := c.conn.run() // returns as soon as the connection is closed + + if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { + c.packetHandlers.Destroy() + } + errorChan <- err + }() + + // only set when we're using 0-RTT + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} + if c.use0RTT { + earlyConnChan = c.conn.earlyConnReady() + } + + select { + case <-ctx.Done(): + c.conn.shutdown() + return ctx.Err() + case err := <-errorChan: + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) + } + return err + case <-earlyConnChan: + // ready to send 0-RTT data + return nil + case <-c.conn.HandshakeComplete().Done(): + // handshake successfully completed + return nil + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/closed_conn.go b/vendor/github.com/lucas-clemente/quic-go/closed_conn.go new file mode 100644 index 00000000..35c2d739 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/closed_conn.go @@ -0,0 +1,112 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// A closedLocalConn is a connection that we closed locally. +// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, +// with an exponential backoff. +type closedLocalConn struct { + conn sendConn + connClosePacket []byte + + closeOnce sync.Once + closeChan chan struct{} // is closed when the connection is closed or destroyed + + receivedPackets chan *receivedPacket + counter uint64 // number of packets received + + perspective protocol.Perspective + + logger utils.Logger +} + +var _ packetHandler = &closedLocalConn{} + +// newClosedLocalConn creates a new closedLocalConn and runs it. +func newClosedLocalConn( + conn sendConn, + connClosePacket []byte, + perspective protocol.Perspective, + logger utils.Logger, +) packetHandler { + s := &closedLocalConn{ + conn: conn, + connClosePacket: connClosePacket, + perspective: perspective, + logger: logger, + closeChan: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, 64), + } + go s.run() + return s +} + +func (s *closedLocalConn) run() { + for { + select { + case p := <-s.receivedPackets: + s.handlePacketImpl(p) + case <-s.closeChan: + return + } + } +} + +func (s *closedLocalConn) handlePacket(p *receivedPacket) { + select { + case s.receivedPackets <- p: + default: + } +} + +func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) { + s.counter++ + // exponential backoff + // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving + for n := s.counter; n > 1; n = n / 2 { + if n%2 != 0 { + return + } + } + s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter) + if err := s.conn.Write(s.connClosePacket); err != nil { + s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err) + } +} + +func (s *closedLocalConn) shutdown() { + s.destroy(nil) +} + +func (s *closedLocalConn) destroy(error) { + s.closeOnce.Do(func() { + close(s.closeChan) + }) +} + +func (s *closedLocalConn) getPerspective() protocol.Perspective { + return s.perspective +} + +// A closedRemoteConn is a connection that was closed remotely. +// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. +// We can just ignore those packets. +type closedRemoteConn struct { + perspective protocol.Perspective +} + +var _ packetHandler = &closedRemoteConn{} + +func newClosedRemoteConn(pers protocol.Perspective) packetHandler { + return &closedRemoteConn{perspective: pers} +} + +func (s *closedRemoteConn) handlePacket(*receivedPacket) {} +func (s *closedRemoteConn) shutdown() {} +func (s *closedRemoteConn) destroy(error) {} +func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/vendor/github.com/lucas-clemente/quic-go/codecov.yml b/vendor/github.com/lucas-clemente/quic-go/codecov.yml new file mode 100644 index 00000000..ee9cfd3b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/codecov.yml @@ -0,0 +1,21 @@ +coverage: + round: nearest + ignore: + - streams_map_incoming_bidi.go + - streams_map_incoming_uni.go + - streams_map_outgoing_bidi.go + - streams_map_outgoing_uni.go + - http3/gzip_reader.go + - interop/ + - internal/ackhandler/packet_linkedlist.go + - internal/utils/byteinterval_linkedlist.go + - internal/utils/newconnectionid_linkedlist.go + - internal/utils/packetinterval_linkedlist.go + - internal/utils/linkedlist/linkedlist.go + - fuzzing/ + - metrics/ + status: + project: + default: + threshold: 0.5 + patch: false diff --git a/vendor/github.com/lucas-clemente/quic-go/config.go b/vendor/github.com/lucas-clemente/quic-go/config.go new file mode 100644 index 00000000..5d969a12 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/config.go @@ -0,0 +1,124 @@ +package quic + +import ( + "errors" + "time" + + "github.com/lucas-clemente/quic-go/internal/utils" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// Clone clones a Config +func (c *Config) Clone() *Config { + copy := *c + return © +} + +func (c *Config) handshakeTimeout() time.Duration { + return utils.MaxDuration(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) +} + +func validateConfig(config *Config) error { + if config == nil { + return nil + } + if config.MaxIncomingStreams > 1<<60 { + return errors.New("invalid value for Config.MaxIncomingStreams") + } + if config.MaxIncomingUniStreams > 1<<60 { + return errors.New("invalid value for Config.MaxIncomingUniStreams") + } + return nil +} + +// populateServerConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil +func populateServerConfig(config *Config) *Config { + config = populateConfig(config) + if config.ConnectionIDLength == 0 { + config.ConnectionIDLength = protocol.DefaultConnectionIDLength + } + if config.AcceptToken == nil { + config.AcceptToken = defaultAcceptToken + } + return config +} + +// populateClientConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil +func populateClientConfig(config *Config, createdPacketConn bool) *Config { + config = populateConfig(config) + if config.ConnectionIDLength == 0 && !createdPacketConn { + config.ConnectionIDLength = protocol.DefaultConnectionIDLength + } + return config +} + +func populateConfig(config *Config) *Config { + if config == nil { + config = &Config{} + } + versions := config.Versions + if len(versions) == 0 { + versions = protocol.SupportedVersions + } + handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout + if config.HandshakeIdleTimeout != 0 { + handshakeIdleTimeout = config.HandshakeIdleTimeout + } + idleTimeout := protocol.DefaultIdleTimeout + if config.MaxIdleTimeout != 0 { + idleTimeout = config.MaxIdleTimeout + } + initialStreamReceiveWindow := config.InitialStreamReceiveWindow + if initialStreamReceiveWindow == 0 { + initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData + } + maxStreamReceiveWindow := config.MaxStreamReceiveWindow + if maxStreamReceiveWindow == 0 { + maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow + } + initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow + if initialConnectionReceiveWindow == 0 { + initialConnectionReceiveWindow = protocol.DefaultInitialMaxData + } + maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow + if maxConnectionReceiveWindow == 0 { + maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow + } + maxIncomingStreams := config.MaxIncomingStreams + if maxIncomingStreams == 0 { + maxIncomingStreams = protocol.DefaultMaxIncomingStreams + } else if maxIncomingStreams < 0 { + maxIncomingStreams = 0 + } + maxIncomingUniStreams := config.MaxIncomingUniStreams + if maxIncomingUniStreams == 0 { + maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams + } else if maxIncomingUniStreams < 0 { + maxIncomingUniStreams = 0 + } + + return &Config{ + Versions: versions, + HandshakeIdleTimeout: handshakeIdleTimeout, + MaxIdleTimeout: idleTimeout, + AcceptToken: config.AcceptToken, + KeepAlivePeriod: config.KeepAlivePeriod, + InitialStreamReceiveWindow: initialStreamReceiveWindow, + MaxStreamReceiveWindow: maxStreamReceiveWindow, + InitialConnectionReceiveWindow: initialConnectionReceiveWindow, + MaxConnectionReceiveWindow: maxConnectionReceiveWindow, + AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + ConnectionIDLength: config.ConnectionIDLength, + StatelessResetKey: config.StatelessResetKey, + TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, + DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, + DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, + Tracer: config.Tracer, + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/conn_id_generator.go b/vendor/github.com/lucas-clemente/quic-go/conn_id_generator.go new file mode 100644 index 00000000..90c2b7a6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/conn_id_generator.go @@ -0,0 +1,140 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type connIDGenerator struct { + connIDLen int + highestSeq uint64 + + activeSrcConnIDs map[uint64]protocol.ConnectionID + initialClientDestConnID protocol.ConnectionID + + addConnectionID func(protocol.ConnectionID) + getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken + removeConnectionID func(protocol.ConnectionID) + retireConnectionID func(protocol.ConnectionID) + replaceWithClosed func(protocol.ConnectionID, packetHandler) + queueControlFrame func(wire.Frame) + + version protocol.VersionNumber +} + +func newConnIDGenerator( + initialConnectionID protocol.ConnectionID, + initialClientDestConnID protocol.ConnectionID, // nil for the client + addConnectionID func(protocol.ConnectionID), + getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, + removeConnectionID func(protocol.ConnectionID), + retireConnectionID func(protocol.ConnectionID), + replaceWithClosed func(protocol.ConnectionID, packetHandler), + queueControlFrame func(wire.Frame), + version protocol.VersionNumber, +) *connIDGenerator { + m := &connIDGenerator{ + connIDLen: initialConnectionID.Len(), + activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), + addConnectionID: addConnectionID, + getStatelessResetToken: getStatelessResetToken, + removeConnectionID: removeConnectionID, + retireConnectionID: retireConnectionID, + replaceWithClosed: replaceWithClosed, + queueControlFrame: queueControlFrame, + version: version, + } + m.activeSrcConnIDs[0] = initialConnectionID + m.initialClientDestConnID = initialClientDestConnID + return m +} + +func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { + if m.connIDLen == 0 { + return nil + } + // The active_connection_id_limit transport parameter is the number of + // connection IDs the peer will store. This limit includes the connection ID + // used during the handshake, and the one sent in the preferred_address + // transport parameter. + // We currently don't send the preferred_address transport parameter, + // so we can issue (limit - 1) connection IDs. + for i := uint64(len(m.activeSrcConnIDs)); i < utils.MinUint64(limit, protocol.MaxIssuedConnectionIDs); i++ { + if err := m.issueNewConnID(); err != nil { + return err + } + } + return nil +} + +func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error { + if seq > m.highestSeq { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq), + } + } + connID, ok := m.activeSrcConnIDs[seq] + // We might already have deleted this connection ID, if this is a duplicate frame. + if !ok { + return nil + } + if connID.Equal(sentWithDestConnID) { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), + } + } + m.retireConnectionID(connID) + delete(m.activeSrcConnIDs, seq) + // Don't issue a replacement for the initial connection ID. + if seq == 0 { + return nil + } + return m.issueNewConnID() +} + +func (m *connIDGenerator) issueNewConnID() error { + connID, err := protocol.GenerateConnectionID(m.connIDLen) + if err != nil { + return err + } + m.activeSrcConnIDs[m.highestSeq+1] = connID + m.addConnectionID(connID) + m.queueControlFrame(&wire.NewConnectionIDFrame{ + SequenceNumber: m.highestSeq + 1, + ConnectionID: connID, + StatelessResetToken: m.getStatelessResetToken(connID), + }) + m.highestSeq++ + return nil +} + +func (m *connIDGenerator) SetHandshakeComplete() { + if m.initialClientDestConnID != nil { + m.retireConnectionID(m.initialClientDestConnID) + m.initialClientDestConnID = nil + } +} + +func (m *connIDGenerator) RemoveAll() { + if m.initialClientDestConnID != nil { + m.removeConnectionID(m.initialClientDestConnID) + } + for _, connID := range m.activeSrcConnIDs { + m.removeConnectionID(connID) + } +} + +func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { + if m.initialClientDestConnID != nil { + m.replaceWithClosed(m.initialClientDestConnID, handler) + } + for _, connID := range m.activeSrcConnIDs { + m.replaceWithClosed(connID, handler) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/conn_id_manager.go b/vendor/github.com/lucas-clemente/quic-go/conn_id_manager.go new file mode 100644 index 00000000..e1b025a9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/conn_id_manager.go @@ -0,0 +1,207 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type connIDManager struct { + queue utils.NewConnectionIDList + + handshakeComplete bool + activeSequenceNumber uint64 + highestRetired uint64 + activeConnectionID protocol.ConnectionID + activeStatelessResetToken *protocol.StatelessResetToken + + // We change the connection ID after sending on average + // protocol.PacketsPerConnectionID packets. The actual value is randomized + // hide the packet loss rate from on-path observers. + rand utils.Rand + packetsSinceLastChange uint32 + packetsPerConnectionID uint32 + + addStatelessResetToken func(protocol.StatelessResetToken) + removeStatelessResetToken func(protocol.StatelessResetToken) + queueControlFrame func(wire.Frame) +} + +func newConnIDManager( + initialDestConnID protocol.ConnectionID, + addStatelessResetToken func(protocol.StatelessResetToken), + removeStatelessResetToken func(protocol.StatelessResetToken), + queueControlFrame func(wire.Frame), +) *connIDManager { + return &connIDManager{ + activeConnectionID: initialDestConnID, + addStatelessResetToken: addStatelessResetToken, + removeStatelessResetToken: removeStatelessResetToken, + queueControlFrame: queueControlFrame, + } +} + +func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { + return h.addConnectionID(1, connID, resetToken) +} + +func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { + if err := h.add(f); err != nil { + return err + } + if h.queue.Len() >= protocol.MaxActiveConnectionIDs { + return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} + } + return nil +} + +func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { + // If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active + // connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately. + if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired { + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: f.SequenceNumber, + }) + return nil + } + + // Retire elements in the queue. + // Doesn't retire the active connection ID. + if f.RetirePriorTo > h.highestRetired { + var next *utils.NewConnectionIDElement + for el := h.queue.Front(); el != nil; el = next { + if el.Value.SequenceNumber >= f.RetirePriorTo { + break + } + next = el.Next() + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: el.Value.SequenceNumber, + }) + h.queue.Remove(el) + } + h.highestRetired = f.RetirePriorTo + } + + if f.SequenceNumber == h.activeSequenceNumber { + return nil + } + + if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil { + return err + } + + // Retire the active connection ID, if necessary. + if h.activeSequenceNumber < f.RetirePriorTo { + // The queue is guaranteed to have at least one element at this point. + h.updateConnectionID() + } + return nil +} + +func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { + // insert a new element at the end + if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq { + h.queue.PushBack(utils.NewConnectionID{ + SequenceNumber: seq, + ConnectionID: connID, + StatelessResetToken: resetToken, + }) + return nil + } + // insert a new element somewhere in the middle + for el := h.queue.Front(); el != nil; el = el.Next() { + if el.Value.SequenceNumber == seq { + if !el.Value.ConnectionID.Equal(connID) { + return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) + } + if el.Value.StatelessResetToken != resetToken { + return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq) + } + break + } + if el.Value.SequenceNumber > seq { + h.queue.InsertBefore(utils.NewConnectionID{ + SequenceNumber: seq, + ConnectionID: connID, + StatelessResetToken: resetToken, + }, el) + break + } + } + return nil +} + +func (h *connIDManager) updateConnectionID() { + h.queueControlFrame(&wire.RetireConnectionIDFrame{ + SequenceNumber: h.activeSequenceNumber, + }) + h.highestRetired = utils.MaxUint64(h.highestRetired, h.activeSequenceNumber) + if h.activeStatelessResetToken != nil { + h.removeStatelessResetToken(*h.activeStatelessResetToken) + } + + front := h.queue.Remove(h.queue.Front()) + h.activeSequenceNumber = front.SequenceNumber + h.activeConnectionID = front.ConnectionID + h.activeStatelessResetToken = &front.StatelessResetToken + h.packetsSinceLastChange = 0 + h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID)) + h.addStatelessResetToken(*h.activeStatelessResetToken) +} + +func (h *connIDManager) Close() { + if h.activeStatelessResetToken != nil { + h.removeStatelessResetToken(*h.activeStatelessResetToken) + } +} + +// is called when the server performs a Retry +// and when the server changes the connection ID in the first Initial sent +func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { + if h.activeSequenceNumber != 0 { + panic("expected first connection ID to have sequence number 0") + } + h.activeConnectionID = newConnID +} + +// is called when the server provides a stateless reset token in the transport parameters +func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) { + if h.activeSequenceNumber != 0 { + panic("expected first connection ID to have sequence number 0") + } + h.activeStatelessResetToken = &token + h.addStatelessResetToken(token) +} + +func (h *connIDManager) SentPacket() { + h.packetsSinceLastChange++ +} + +func (h *connIDManager) shouldUpdateConnID() bool { + if !h.handshakeComplete { + return false + } + // initiate the first change as early as possible (after handshake completion) + if h.queue.Len() > 0 && h.activeSequenceNumber == 0 { + return true + } + // For later changes, only change if + // 1. The queue of connection IDs is filled more than 50%. + // 2. We sent at least PacketsPerConnectionID packets + return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs && + h.packetsSinceLastChange >= h.packetsPerConnectionID +} + +func (h *connIDManager) Get() protocol.ConnectionID { + if h.shouldUpdateConnID() { + h.updateConnectionID() + } + return h.activeConnectionID +} + +func (h *connIDManager) SetHandshakeComplete() { + h.handshakeComplete = true +} diff --git a/vendor/github.com/lucas-clemente/quic-go/connection.go b/vendor/github.com/lucas-clemente/quic-go/connection.go new file mode 100644 index 00000000..ce45af86 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/connection.go @@ -0,0 +1,2006 @@ +package quic + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/logutils" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" +) + +type unpacker interface { + Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) +} + +type streamGetter interface { + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) +} + +type streamManager interface { + GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error) + GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error) + OpenStream() (Stream, error) + OpenUniStream() (SendStream, error) + OpenStreamSync(context.Context) (Stream, error) + OpenUniStreamSync(context.Context) (SendStream, error) + AcceptStream(context.Context) (Stream, error) + AcceptUniStream(context.Context) (ReceiveStream, error) + DeleteStream(protocol.StreamID) error + UpdateLimits(*wire.TransportParameters) + HandleMaxStreamsFrame(*wire.MaxStreamsFrame) + CloseWithError(error) + ResetFor0RTT() + UseResetMaps() +} + +type cryptoStreamHandler interface { + RunHandshake() + ChangeConnectionID(protocol.ConnectionID) + SetLargest1RTTAcked(protocol.PacketNumber) error + SetHandshakeConfirmed() + GetSessionTicket() ([]byte, error) + io.Closer + ConnectionState() handshake.ConnectionState +} + +type packetInfo struct { + addr net.IP + ifIndex uint32 +} + +type receivedPacket struct { + buffer *packetBuffer + + remoteAddr net.Addr + rcvTime time.Time + data []byte + + ecn protocol.ECN + + info *packetInfo +} + +func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } + +func (p *receivedPacket) Clone() *receivedPacket { + return &receivedPacket{ + remoteAddr: p.remoteAddr, + rcvTime: p.rcvTime, + data: p.data, + buffer: p.buffer, + ecn: p.ecn, + info: p.info, + } +} + +type connRunner interface { + Add(protocol.ConnectionID, packetHandler) bool + GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken + Retire(protocol.ConnectionID) + Remove(protocol.ConnectionID) + ReplaceWithClosed(protocol.ConnectionID, packetHandler) + AddResetToken(protocol.StatelessResetToken, packetHandler) + RemoveResetToken(protocol.StatelessResetToken) +} + +type handshakeRunner struct { + onReceivedParams func(*wire.TransportParameters) + onError func(error) + dropKeys func(protocol.EncryptionLevel) + onHandshakeComplete func() +} + +func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) } +func (r *handshakeRunner) OnError(e error) { r.onError(e) } +func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } +func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } + +type closeError struct { + err error + remote bool + immediate bool +} + +type errCloseForRecreating struct { + nextPacketNumber protocol.PacketNumber + nextVersion protocol.VersionNumber +} + +func (e *errCloseForRecreating) Error() string { + return "closing connection in order to recreate it" +} + +var connTracingID uint64 // to be accessed atomically +func nextConnTracingID() uint64 { return atomic.AddUint64(&connTracingID, 1) } + +// A Connection is a QUIC connection +type connection struct { + // Destination connection ID used during the handshake. + // Used to check source connection ID on incoming packets. + handshakeDestConnID protocol.ConnectionID + // Set for the client. Destination connection ID used on the first Initial sent. + origDestConnID protocol.ConnectionID + retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed) + + srcConnIDLen int + + perspective protocol.Perspective + version protocol.VersionNumber + config *Config + + conn sendConn + sendQueue sender + + streamsMap streamManager + connIDManager *connIDManager + connIDGenerator *connIDGenerator + + rttStats *utils.RTTStats + + cryptoStreamManager *cryptoStreamManager + sentPacketHandler ackhandler.SentPacketHandler + receivedPacketHandler ackhandler.ReceivedPacketHandler + retransmissionQueue *retransmissionQueue + framer framer + windowUpdateQueue *windowUpdateQueue + connFlowController flowcontrol.ConnectionFlowController + tokenStoreKey string // only set for the client + tokenGenerator *handshake.TokenGenerator // only set for the server + + unpacker unpacker + frameParser wire.FrameParser + packer packer + mtuDiscoverer mtuDiscoverer // initialized when the handshake completes + + oneRTTStream cryptoStream // only set for the server + cryptoStreamHandler cryptoStreamHandler + + receivedPackets chan *receivedPacket + sendingScheduled chan struct{} + + closeOnce sync.Once + // closeChan is used to notify the run loop that it should terminate + closeChan chan closeError + + ctx context.Context + ctxCancel context.CancelFunc + handshakeCtx context.Context + handshakeCtxCancel context.CancelFunc + + undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level + undecryptablePacketsToProcess []*receivedPacket + + clientHelloWritten <-chan *wire.TransportParameters + earlyConnReadyChan chan struct{} + handshakeCompleteChan chan struct{} // is closed when the handshake completes + sentFirstPacket bool + handshakeComplete bool + handshakeConfirmed bool + + receivedRetry bool + versionNegotiated bool + receivedFirstPacket bool + + idleTimeout time.Duration + creationTime time.Time + // The idle timeout is set based on the max of the time we received the last packet... + lastPacketReceivedTime time.Time + // ... and the time we sent a new ack-eliciting packet after receiving a packet. + firstAckElicitingPacketAfterIdleSentTime time.Time + // pacingDeadline is the time when the next packet should be sent + pacingDeadline time.Time + + peerParams *wire.TransportParameters + + timer *utils.Timer + // keepAlivePingSent stores whether a keep alive PING is in flight. + // It is reset as soon as we receive a packet from the peer. + keepAlivePingSent bool + keepAliveInterval time.Duration + + datagramQueue *datagramQueue + + logID string + tracer logging.ConnectionTracer + logger utils.Logger +} + +var ( + _ Connection = &connection{} + _ EarlyConnection = &connection{} + _ streamSender = &connection{} + deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine +) + +var newConnection = func( + conn sendConn, + runner connRunner, + origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, + clientDestConnID protocol.ConnectionID, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + statelessResetToken protocol.StatelessResetToken, + conf *Config, + tlsConf *tls.Config, + tokenGenerator *handshake.TokenGenerator, + enable0RTT bool, + tracer logging.ConnectionTracer, + tracingID uint64, + logger utils.Logger, + v protocol.VersionNumber, +) quicConn { + s := &connection{ + conn: conn, + config: conf, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + tokenGenerator: tokenGenerator, + oneRTTStream: newCryptoStream(), + perspective: protocol.PerspectiveServer, + handshakeCompleteChan: make(chan struct{}), + tracer: tracer, + logger: logger, + version: v, + } + if origDestConnID != nil { + s.logID = origDestConnID.String() + } else { + s.logID = destConnID.String() + } + s.connIDManager = newConnIDManager( + destConnID, + func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, + s.queueControlFrame, + ) + s.connIDGenerator = newConnIDGenerator( + srcConnID, + clientDestConnID, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, + runner.Remove, + runner.Retire, + runner.ReplaceWithClosed, + s.queueControlFrame, + s.version, + ) + s.preSetup() + s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( + 0, + getMaxPacketSize(s.conn.RemoteAddr()), + s.rttStats, + s.perspective, + s.tracer, + s.logger, + s.version, + ) + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + params := &wire.TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + MaxIdleTimeout: s.config.MaxIdleTimeout, + MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + MaxAckDelay: protocol.MaxAckDelayInclGranularity, + AckDelayExponent: protocol.AckDelayExponent, + DisableActiveMigration: true, + StatelessResetToken: &statelessResetToken, + OriginalDestinationConnectionID: origDestConnID, + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, + RetrySourceConnectionID: retrySrcConnID, + } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } + if s.tracer != nil { + s.tracer.SentTransportParameters(params) + } + cs := handshake.NewCryptoSetupServer( + initialStream, + handshakeStream, + clientDestConnID, + conn.LocalAddr(), + conn.RemoteAddr(), + params, + &handshakeRunner{ + onReceivedParams: s.handleTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + onHandshakeComplete: func() { + runner.Retire(clientDestConnID) + close(s.handshakeCompleteChan) + }, + }, + tlsConf, + enable0RTT, + s.rttStats, + tracer, + logger, + s.version, + ) + s.cryptoStreamHandler = cs + s.packer = newPacketPacker( + srcConnID, + s.connIDManager.Get, + initialStream, + handshakeStream, + s.sentPacketHandler, + s.retransmissionQueue, + s.RemoteAddr(), + cs, + s.framer, + s.receivedPacketHandler, + s.datagramQueue, + s.perspective, + s.version, + ) + s.unpacker = newPacketUnpacker(cs, s.version) + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) + return s +} + +// declare this as a variable, such that we can it mock it in the tests +var newClientConnection = func( + conn sendConn, + runner connRunner, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + conf *Config, + tlsConf *tls.Config, + initialPacketNumber protocol.PacketNumber, + enable0RTT bool, + hasNegotiatedVersion bool, + tracer logging.ConnectionTracer, + tracingID uint64, + logger utils.Logger, + v protocol.VersionNumber, +) quicConn { + s := &connection{ + conn: conn, + config: conf, + origDestConnID: destConnID, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + perspective: protocol.PerspectiveClient, + handshakeCompleteChan: make(chan struct{}), + logID: destConnID.String(), + logger: logger, + tracer: tracer, + versionNegotiated: hasNegotiatedVersion, + version: v, + } + s.connIDManager = newConnIDManager( + destConnID, + func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, + s.queueControlFrame, + ) + s.connIDGenerator = newConnIDGenerator( + srcConnID, + nil, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, + runner.Remove, + runner.Retire, + runner.ReplaceWithClosed, + s.queueControlFrame, + s.version, + ) + s.preSetup() + s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( + initialPacketNumber, + getMaxPacketSize(s.conn.RemoteAddr()), + s.rttStats, + s.perspective, + s.tracer, + s.logger, + s.version, + ) + initialStream := newCryptoStream() + handshakeStream := newCryptoStream() + params := &wire.TransportParameters{ + InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + MaxIdleTimeout: s.config.MaxIdleTimeout, + MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + MaxAckDelay: protocol.MaxAckDelayInclGranularity, + AckDelayExponent: protocol.AckDelayExponent, + DisableActiveMigration: true, + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, + } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } + if s.tracer != nil { + s.tracer.SentTransportParameters(params) + } + cs, clientHelloWritten := handshake.NewCryptoSetupClient( + initialStream, + handshakeStream, + destConnID, + conn.LocalAddr(), + conn.RemoteAddr(), + params, + &handshakeRunner{ + onReceivedParams: s.handleTransportParameters, + onError: s.closeLocal, + dropKeys: s.dropEncryptionLevel, + onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, + }, + tlsConf, + enable0RTT, + s.rttStats, + tracer, + logger, + s.version, + ) + s.clientHelloWritten = clientHelloWritten + s.cryptoStreamHandler = cs + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) + s.unpacker = newPacketUnpacker(cs, s.version) + s.packer = newPacketPacker( + srcConnID, + s.connIDManager.Get, + initialStream, + handshakeStream, + s.sentPacketHandler, + s.retransmissionQueue, + s.RemoteAddr(), + cs, + s.framer, + s.receivedPacketHandler, + s.datagramQueue, + s.perspective, + s.version, + ) + if len(tlsConf.ServerName) > 0 { + s.tokenStoreKey = tlsConf.ServerName + } else { + s.tokenStoreKey = conn.RemoteAddr().String() + } + if s.config.TokenStore != nil { + if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil { + s.packer.SetToken(token.data) + } + } + return s +} + +func (s *connection) preSetup() { + s.sendQueue = newSendQueue(s.conn) + s.retransmissionQueue = newRetransmissionQueue(s.version) + s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) + s.rttStats = &utils.RTTStats{} + s.connFlowController = flowcontrol.NewConnectionFlowController( + protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + protocol.ByteCount(s.config.MaxConnectionReceiveWindow), + s.onHasConnectionWindowUpdate, + func(size protocol.ByteCount) bool { + if s.config.AllowConnectionWindowIncrease == nil { + return true + } + return s.config.AllowConnectionWindowIncrease(s, uint64(size)) + }, + s.rttStats, + s.logger, + ) + s.earlyConnReadyChan = make(chan struct{}) + s.streamsMap = newStreamsMap( + s, + s.newFlowController, + uint64(s.config.MaxIncomingStreams), + uint64(s.config.MaxIncomingUniStreams), + s.perspective, + s.version, + ) + s.framer = newFramer(s.streamsMap, s.version) + s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) + s.closeChan = make(chan closeError, 1) + s.sendingScheduled = make(chan struct{}, 1) + s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) + + now := time.Now() + s.lastPacketReceivedTime = now + s.creationTime = now + + s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) + if s.config.EnableDatagrams { + s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) + } +} + +// run the connection main loop +func (s *connection) run() error { + defer s.ctxCancel() + + s.timer = utils.NewTimer() + + go s.cryptoStreamHandler.RunHandshake() + go func() { + if err := s.sendQueue.Run(); err != nil { + s.destroyImpl(err) + } + }() + + if s.perspective == protocol.PerspectiveClient { + select { + case zeroRTTParams := <-s.clientHelloWritten: + s.scheduleSending() + if zeroRTTParams != nil { + s.restoreTransportParameters(zeroRTTParams) + close(s.earlyConnReadyChan) + } + case closeErr := <-s.closeChan: + // put the close error back into the channel, so that the run loop can receive it + s.closeChan <- closeErr + } + } + + var ( + closeErr closeError + sendQueueAvailable <-chan struct{} + ) + +runLoop: + for { + // Close immediately if requested + select { + case closeErr = <-s.closeChan: + break runLoop + case <-s.handshakeCompleteChan: + s.handleHandshakeComplete() + default: + } + + s.maybeResetTimer() + + var processedUndecryptablePacket bool + if len(s.undecryptablePacketsToProcess) > 0 { + queue := s.undecryptablePacketsToProcess + s.undecryptablePacketsToProcess = nil + for _, p := range queue { + if processed := s.handlePacketImpl(p); processed { + processedUndecryptablePacket = true + } + // Don't set timers and send packets if the packet made us close the connection. + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } + } + } + // If we processed any undecryptable packets, jump to the resetting of the timers directly. + if !processedUndecryptablePacket { + select { + case closeErr = <-s.closeChan: + break runLoop + case <-s.timer.Chan(): + s.timer.SetRead() + // We do all the interesting stuff after the switch statement, so + // nothing to see here. + case <-s.sendingScheduled: + // We do all the interesting stuff after the switch statement, so + // nothing to see here. + case <-sendQueueAvailable: + case firstPacket := <-s.receivedPackets: + wasProcessed := s.handlePacketImpl(firstPacket) + // Don't set timers and send packets if the packet made us close the connection. + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } + if s.handshakeComplete { + // Now process all packets in the receivedPackets channel. + // Limit the number of packets to the length of the receivedPackets channel, + // so we eventually get a chance to send out an ACK when receiving a lot of packets. + numPackets := len(s.receivedPackets) + receiveLoop: + for i := 0; i < numPackets; i++ { + select { + case p := <-s.receivedPackets: + if processed := s.handlePacketImpl(p); processed { + wasProcessed = true + } + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } + default: + break receiveLoop + } + } + } + // Only reset the timers if this packet was actually processed. + // This avoids modifying any state when handling undecryptable packets, + // which could be injected by an attacker. + if !wasProcessed { + continue + } + case <-s.handshakeCompleteChan: + s.handleHandshakeComplete() + } + } + + now := time.Now() + if timeout := s.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) { + // This could cause packets to be retransmitted. + // Check it before trying to send packets. + if err := s.sentPacketHandler.OnLossDetectionTimeout(); err != nil { + s.closeLocal(err) + } + } + + if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() && !now.Before(keepAliveTime) { + // send a PING frame since there is no activity in the connection + s.logger.Debugf("Sending a keep-alive PING to keep the connection alive.") + s.framer.QueueControlFrame(&wire.PingFrame{}) + s.keepAlivePingSent = true + } else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() { + s.destroyImpl(qerr.ErrHandshakeTimeout) + continue + } else { + idleTimeoutStartTime := s.idleTimeoutStartTime() + if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) || + (s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.idleTimeout) { + s.destroyImpl(qerr.ErrIdleTimeout) + continue + } + } + + if s.sendQueue.WouldBlock() { + // The send queue is still busy sending out packets. + // Wait until there's space to enqueue new packets. + sendQueueAvailable = s.sendQueue.Available() + continue + } + if err := s.sendPackets(); err != nil { + s.closeLocal(err) + } + if s.sendQueue.WouldBlock() { + sendQueueAvailable = s.sendQueue.Available() + } else { + sendQueueAvailable = nil + } + } + + s.handleCloseError(&closeErr) + if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { + s.tracer.Close() + } + s.logger.Infof("Connection %s closed.", s.logID) + s.cryptoStreamHandler.Close() + s.sendQueue.Close() + s.timer.Stop() + return closeErr.err +} + +// blocks until the early connection can be used +func (s *connection) earlyConnReady() <-chan struct{} { + return s.earlyConnReadyChan +} + +func (s *connection) HandshakeComplete() context.Context { + return s.handshakeCtx +} + +func (s *connection) Context() context.Context { + return s.ctx +} + +func (s *connection) supportsDatagrams() bool { + return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount +} + +func (s *connection) ConnectionState() ConnectionState { + return ConnectionState{ + TLS: s.cryptoStreamHandler.ConnectionState(), + SupportsDatagrams: s.supportsDatagrams(), + } +} + +// Time when the next keep-alive packet should be sent. +// It returns a zero time if no keep-alive should be sent. +func (s *connection) nextKeepAliveTime() time.Time { + if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { + return time.Time{} + } + return s.lastPacketReceivedTime.Add(s.keepAliveInterval) +} + +func (s *connection) maybeResetTimer() { + var deadline time.Time + if !s.handshakeComplete { + deadline = utils.MinTime( + s.creationTime.Add(s.config.handshakeTimeout()), + s.idleTimeoutStartTime().Add(s.config.HandshakeIdleTimeout), + ) + } else { + if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() { + deadline = keepAliveTime + } else { + deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) + } + } + + if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { + deadline = utils.MinTime(deadline, ackAlarm) + } + if lossTime := s.sentPacketHandler.GetLossDetectionTimeout(); !lossTime.IsZero() { + deadline = utils.MinTime(deadline, lossTime) + } + if !s.pacingDeadline.IsZero() { + deadline = utils.MinTime(deadline, s.pacingDeadline) + } + + s.timer.Reset(deadline) +} + +func (s *connection) idleTimeoutStartTime() time.Time { + return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime) +} + +func (s *connection) handleHandshakeComplete() { + s.handshakeComplete = true + s.handshakeCompleteChan = nil // prevent this case from ever being selected again + defer s.handshakeCtxCancel() + // Once the handshake completes, we have derived 1-RTT keys. + // There's no point in queueing undecryptable packets for later decryption any more. + s.undecryptablePackets = nil + + s.connIDManager.SetHandshakeComplete() + s.connIDGenerator.SetHandshakeComplete() + + if s.perspective == protocol.PerspectiveClient { + s.applyTransportParameters() + return + } + + s.handleHandshakeConfirmed() + + ticket, err := s.cryptoStreamHandler.GetSessionTicket() + if err != nil { + s.closeLocal(err) + } + if ticket != nil { + s.oneRTTStream.Write(ticket) + for s.oneRTTStream.HasData() { + s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) + } + } + token, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr()) + if err != nil { + s.closeLocal(err) + } + s.queueControlFrame(&wire.NewTokenFrame{Token: token}) + s.queueControlFrame(&wire.HandshakeDoneFrame{}) +} + +func (s *connection) handleHandshakeConfirmed() { + s.handshakeConfirmed = true + s.sentPacketHandler.SetHandshakeConfirmed() + s.cryptoStreamHandler.SetHandshakeConfirmed() + + if !s.config.DisablePathMTUDiscovery { + maxPacketSize := s.peerParams.MaxUDPPayloadSize + if maxPacketSize == 0 { + maxPacketSize = protocol.MaxByteCount + } + maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize) + s.mtuDiscoverer = newMTUDiscoverer( + s.rttStats, + getMaxPacketSize(s.conn.RemoteAddr()), + maxPacketSize, + func(size protocol.ByteCount) { + s.sentPacketHandler.SetMaxDatagramSize(size) + s.packer.SetMaxPacketSize(size) + }, + ) + } +} + +func (s *connection) handlePacketImpl(rp *receivedPacket) bool { + s.sentPacketHandler.ReceivedBytes(rp.Size()) + + if wire.IsVersionNegotiationPacket(rp.data) { + s.handleVersionNegotiationPacket(rp) + return false + } + + var counter uint8 + var lastConnID protocol.ConnectionID + var processed bool + data := rp.data + p := rp + for len(data) > 0 { + if counter > 0 { + p = p.Clone() + p.data = data + } + + hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen) + if err != nil { + if s.tracer != nil { + dropReason := logging.PacketDropHeaderParseError + if err == wire.ErrUnsupportedVersion { + dropReason = logging.PacketDropUnsupportedVersion + } + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + } + s.logger.Debugf("error parsing packet: %s", err) + break + } + + if hdr.IsLongHeader && hdr.Version != s.version { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + } + s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) + break + } + + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + } + s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + break + } + lastConnID = hdr.DestConnectionID + + if counter > 0 { + p.buffer.Split() + } + counter++ + + // only log if this actually a coalesced packet + if s.logger.Debug() && (counter > 1 || len(rest) > 0) { + s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) + } + p.data = packetData + if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed { + processed = true + } + data = rest + } + p.buffer.MaybeRelease() + return processed +} + +func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { + var wasQueued bool + + defer func() { + // Put back the packet buffer if the packet wasn't queued for later decryption. + if !wasQueued { + p.buffer.Decrement() + } + }() + + if hdr.Type == protocol.PacketTypeRetry { + return s.handleRetryPacket(hdr, p.data) + } + + // The server can change the source connection ID with the first Handshake packet. + // After this, all packets with a different source connection have to be ignored. + if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) + } + s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) + return false + } + // drop 0-RTT packets, if we are a client + if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) + } + return false + } + + packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data) + if err != nil { + switch err { + case handshake.ErrKeysDropped: + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable) + } + s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size()) + case handshake.ErrKeysNotYetAvailable: + // Sealer for this encryption level not yet available. + // Try again later. + wasQueued = true + s.tryQueueingUndecryptablePacket(p, hdr) + case wire.ErrInvalidReservedBits: + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: err.Error(), + }) + case handshake.ErrDecryptionFailed: + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err) + default: + var headerErr *headerParseError + if errors.As(err, &headerErr) { + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err) + } else { + // This is an error returned by the AEAD (other than ErrDecryptionFailed). + // For example, a PROTOCOL_VIOLATION due to key updates. + s.closeLocal(err) + } + } + return false + } + + if s.logger.Debug() { + s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.packetNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel) + packet.hdr.Log(s.logger) + } + + if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.packetNumber, packet.encryptionLevel) { + s.logger.Debugf("Dropping (potentially) duplicate packet.") + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) + } + return false + } + + if err := s.handleUnpackedPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil { + s.closeLocal(err) + return false + } + return true +} + +func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { + if s.perspective == protocol.PerspectiveServer { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + } + s.logger.Debugf("Ignoring Retry.") + return false + } + if s.receivedFirstPacket { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + } + s.logger.Debugf("Ignoring Retry, since we already received a packet.") + return false + } + destConnID := s.connIDManager.Get() + if hdr.SrcConnectionID.Equal(destConnID) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + } + s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") + return false + } + // If a token is already set, this means that we already received a Retry from the server. + // Ignore this Retry packet. + if s.receivedRetry { + s.logger.Debugf("Ignoring Retry, since a Retry was already received.") + return false + } + + tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) + if !bytes.Equal(data[len(data)-16:], tag[:]) { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) + } + s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") + return false + } + + if s.logger.Debug() { + s.logger.Debugf("<- Received Retry:") + (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) + s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) + } + if s.tracer != nil { + s.tracer.ReceivedRetry(hdr) + } + newDestConnID := hdr.SrcConnectionID + s.receivedRetry = true + if err := s.sentPacketHandler.ResetForRetry(); err != nil { + s.closeLocal(err) + return false + } + s.handshakeDestConnID = newDestConnID + s.retrySrcConnID = &newDestConnID + s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) + s.packer.SetToken(hdr.Token) + s.connIDManager.ChangeInitialConnID(newDestConnID) + s.scheduleSending() + return true +} + +func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { + if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets + s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + } + return + } + + hdr, supportedVersions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(p.data)) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) + return + } + + for _, v := range supportedVersions { + if v == s.version { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + } + // The Version Negotiation packet contains the version that we offered. + // This might be a packet sent by an attacker, or it was corrupted. + return + } + } + + s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) + if s.tracer != nil { + s.tracer.ReceivedVersionNegotiationPacket(hdr, supportedVersions) + } + newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) + if !ok { + s.destroyImpl(&VersionNegotiationError{ + Ours: s.config.Versions, + Theirs: supportedVersions, + }) + s.logger.Infof("No compatible QUIC version found.") + return + } + if s.tracer != nil { + s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) + } + + s.logger.Infof("Switching to QUIC version %s.", newVersion) + nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial) + s.destroyImpl(&errCloseForRecreating{ + nextPacketNumber: nextPN, + nextVersion: newVersion, + }) +} + +func (s *connection) handleUnpackedPacket( + packet *unpackedPacket, + ecn protocol.ECN, + rcvTime time.Time, + packetSize protocol.ByteCount, // only for logging +) error { + if len(packet.data) == 0 { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + } + } + + if !s.receivedFirstPacket { + s.receivedFirstPacket = true + if !s.versionNegotiated && s.tracer != nil { + var clientVersions, serverVersions []protocol.VersionNumber + switch s.perspective { + case protocol.PerspectiveClient: + clientVersions = s.config.Versions + case protocol.PerspectiveServer: + serverVersions = s.config.Versions + } + s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions) + } + // The server can change the source connection ID with the first Handshake packet. + if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + cid := packet.hdr.SrcConnectionID + s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) + s.handshakeDestConnID = cid + s.connIDManager.ChangeInitialConnID(cid) + } + // We create the connection as soon as we receive the first packet from the client. + // We do that before authenticating the packet. + // That means that if the source connection ID was corrupted, + // we might have create a connection with an incorrect source connection ID. + // Once we authenticate the first packet, we need to update it. + if s.perspective == protocol.PerspectiveServer { + if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + s.handshakeDestConnID = packet.hdr.SrcConnectionID + s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) + } + if s.tracer != nil { + s.tracer.StartedConnection( + s.conn.LocalAddr(), + s.conn.RemoteAddr(), + packet.hdr.SrcConnectionID, + packet.hdr.DestConnectionID, + ) + } + } + } + + s.lastPacketReceivedTime = rcvTime + s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} + s.keepAlivePingSent = false + + // Only used for tracing. + // If we're not tracing, this slice will always remain empty. + var frames []wire.Frame + r := bytes.NewReader(packet.data) + var isAckEliciting bool + for { + frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) + if err != nil { + return err + } + if frame == nil { + break + } + if ackhandler.IsFrameAckEliciting(frame) { + isAckEliciting = true + } + // Only process frames now if we're not logging. + // If we're logging, we need to make sure that the packet_received event is logged first. + if s.tracer == nil { + if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { + return err + } + } else { + frames = append(frames, frame) + } + } + + if s.tracer != nil { + fs := make([]logging.Frame, len(frames)) + for i, frame := range frames { + fs[i] = logutils.ConvertFrame(frame) + } + s.tracer.ReceivedPacket(packet.hdr, packetSize, fs) + for _, frame := range frames { + if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { + return err + } + } + } + + return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) +} + +func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { + var err error + wire.LogFrame(s.logger, f, false) + switch frame := f.(type) { + case *wire.CryptoFrame: + err = s.handleCryptoFrame(frame, encLevel) + case *wire.StreamFrame: + err = s.handleStreamFrame(frame) + case *wire.AckFrame: + err = s.handleAckFrame(frame, encLevel) + case *wire.ConnectionCloseFrame: + s.handleConnectionCloseFrame(frame) + case *wire.ResetStreamFrame: + err = s.handleResetStreamFrame(frame) + case *wire.MaxDataFrame: + s.handleMaxDataFrame(frame) + case *wire.MaxStreamDataFrame: + err = s.handleMaxStreamDataFrame(frame) + case *wire.MaxStreamsFrame: + s.handleMaxStreamsFrame(frame) + case *wire.DataBlockedFrame: + case *wire.StreamDataBlockedFrame: + case *wire.StreamsBlockedFrame: + case *wire.StopSendingFrame: + err = s.handleStopSendingFrame(frame) + case *wire.PingFrame: + case *wire.PathChallengeFrame: + s.handlePathChallengeFrame(frame) + case *wire.PathResponseFrame: + // since we don't send PATH_CHALLENGEs, we don't expect PATH_RESPONSEs + err = errors.New("unexpected PATH_RESPONSE frame") + case *wire.NewTokenFrame: + err = s.handleNewTokenFrame(frame) + case *wire.NewConnectionIDFrame: + err = s.handleNewConnectionIDFrame(frame) + case *wire.RetireConnectionIDFrame: + err = s.handleRetireConnectionIDFrame(frame, destConnID) + case *wire.HandshakeDoneFrame: + err = s.handleHandshakeDoneFrame() + case *wire.DatagramFrame: + err = s.handleDatagramFrame(frame) + default: + err = fmt.Errorf("unexpected frame type: %s", reflect.ValueOf(&frame).Elem().Type().Name()) + } + return err +} + +// handlePacket is called by the server with a new packet +func (s *connection) handlePacket(p *receivedPacket) { + // Discard packets once the amount of queued packets is larger than + // the channel size, protocol.MaxConnUnprocessedPackets + select { + case s.receivedPackets <- p: + default: + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { + if frame.IsApplicationError { + s.closeRemote(&qerr.ApplicationError{ + Remote: true, + ErrorCode: qerr.ApplicationErrorCode(frame.ErrorCode), + ErrorMessage: frame.ReasonPhrase, + }) + return + } + s.closeRemote(&qerr.TransportError{ + Remote: true, + ErrorCode: qerr.TransportErrorCode(frame.ErrorCode), + FrameType: frame.FrameType, + ErrorMessage: frame.ReasonPhrase, + }) +} + +func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { + encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) + if err != nil { + return err + } + if encLevelChanged { + // Queue all packets for decryption that have been undecryptable so far. + s.undecryptablePacketsToProcess = s.undecryptablePackets + s.undecryptablePackets = nil + } + return nil +} + +func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error { + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // Stream is closed and already garbage collected + // ignore this StreamFrame + return nil + } + return str.handleStreamFrame(frame) +} + +func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) { + s.connFlowController.UpdateSendWindow(frame.MaximumData) +} + +func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.updateSendWindow(frame.MaximumStreamData) + return nil +} + +func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { + s.streamsMap.HandleMaxStreamsFrame(frame) +} + +func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { + str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + return str.handleResetStreamFrame(frame) +} + +func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error { + str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) + if err != nil { + return err + } + if str == nil { + // stream is closed and already garbage collected + return nil + } + str.handleStopSendingFrame(frame) + return nil +} + +func (s *connection) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { + s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) +} + +func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error { + if s.perspective == protocol.PerspectiveServer { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received NEW_TOKEN frame from the client", + } + } + if s.config.TokenStore != nil { + s.config.TokenStore.Put(s.tokenStoreKey, &ClientToken{data: frame.Token}) + } + return nil +} + +func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error { + return s.connIDManager.Add(f) +} + +func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { + return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) +} + +func (s *connection) handleHandshakeDoneFrame() error { + if s.perspective == protocol.PerspectiveServer { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received a HANDSHAKE_DONE frame", + } + } + if !s.handshakeConfirmed { + s.handleHandshakeConfirmed() + } + return nil +} + +func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { + acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime) + if err != nil { + return err + } + if !acked1RTTPacket { + return nil + } + if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed { + s.handleHandshakeConfirmed() + } + return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) +} + +func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { + if f.Length(s.version) > protocol.MaxDatagramFrameSize { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "DATAGRAM frame too large", + } + } + s.datagramQueue.HandleDatagramFrame(f) + return nil +} + +// closeLocal closes the connection and send a CONNECTION_CLOSE containing the error +func (s *connection) closeLocal(e error) { + s.closeOnce.Do(func() { + if e == nil { + s.logger.Infof("Closing connection.") + } else { + s.logger.Errorf("Closing connection with error: %s", e) + } + s.closeChan <- closeError{err: e, immediate: false, remote: false} + }) +} + +// destroy closes the connection without sending the error on the wire +func (s *connection) destroy(e error) { + s.destroyImpl(e) + <-s.ctx.Done() +} + +func (s *connection) destroyImpl(e error) { + s.closeOnce.Do(func() { + if nerr, ok := e.(net.Error); ok && nerr.Timeout() { + s.logger.Errorf("Destroying connection: %s", e) + } else { + s.logger.Errorf("Destroying connection with error: %s", e) + } + s.closeChan <- closeError{err: e, immediate: true, remote: false} + }) +} + +func (s *connection) closeRemote(e error) { + s.closeOnce.Do(func() { + s.logger.Errorf("Peer closed connection with error: %s", e) + s.closeChan <- closeError{err: e, immediate: true, remote: true} + }) +} + +// Close the connection. It sends a NO_ERROR application error. +// It waits until the run loop has stopped before returning +func (s *connection) shutdown() { + s.closeLocal(nil) + <-s.ctx.Done() +} + +func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { + s.closeLocal(&qerr.ApplicationError{ + ErrorCode: code, + ErrorMessage: desc, + }) + <-s.ctx.Done() + return nil +} + +func (s *connection) handleCloseError(closeErr *closeError) { + e := closeErr.err + if e == nil { + e = &qerr.ApplicationError{} + } else { + defer func() { + closeErr.err = e + }() + } + + var ( + statelessResetErr *StatelessResetError + versionNegotiationErr *VersionNegotiationError + recreateErr *errCloseForRecreating + applicationErr *ApplicationError + transportErr *TransportError + ) + switch { + case errors.Is(e, qerr.ErrIdleTimeout), + errors.Is(e, qerr.ErrHandshakeTimeout), + errors.As(e, &statelessResetErr), + errors.As(e, &versionNegotiationErr), + errors.As(e, &recreateErr), + errors.As(e, &applicationErr), + errors.As(e, &transportErr): + default: + e = &qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: e.Error(), + } + } + + s.streamsMap.CloseWithError(e) + s.connIDManager.Close() + if s.datagramQueue != nil { + s.datagramQueue.CloseWithError(e) + } + + if s.tracer != nil && !errors.As(e, &recreateErr) { + s.tracer.ClosedConnection(e) + } + + // If this is a remote close we're done here + if closeErr.remote { + s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective)) + return + } + if closeErr.immediate { + s.connIDGenerator.RemoveAll() + return + } + // Don't send out any CONNECTION_CLOSE if this is an error that occurred + // before we even sent out the first packet. + if s.perspective == protocol.PerspectiveClient && !s.sentFirstPacket { + s.connIDGenerator.RemoveAll() + return + } + connClosePacket, err := s.sendConnectionClose(e) + if err != nil { + s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) + } + cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger) + s.connIDGenerator.ReplaceWithClosed(cs) +} + +func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { + s.sentPacketHandler.DropPackets(encLevel) + s.receivedPacketHandler.DropPackets(encLevel) + if s.tracer != nil { + s.tracer.DroppedEncryptionLevel(encLevel) + } + if encLevel == protocol.Encryption0RTT { + s.streamsMap.ResetFor0RTT() + if err := s.connFlowController.Reset(); err != nil { + s.closeLocal(err) + } + if err := s.framer.Handle0RTTRejection(); err != nil { + s.closeLocal(err) + } + } +} + +// is called for the client, when restoring transport parameters saved for 0-RTT +func (s *connection) restoreTransportParameters(params *wire.TransportParameters) { + if s.logger.Debug() { + s.logger.Debugf("Restoring Transport Parameters: %s", params) + } + + s.peerParams = params + s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) + s.connFlowController.UpdateSendWindow(params.InitialMaxData) + s.streamsMap.UpdateLimits(params) +} + +func (s *connection) handleTransportParameters(params *wire.TransportParameters) { + if err := s.checkTransportParameters(params); err != nil { + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + }) + } + s.peerParams = params + // On the client side we have to wait for handshake completion. + // During a 0-RTT connection, we are only allowed to use the new transport parameters for 1-RTT packets. + if s.perspective == protocol.PerspectiveServer { + s.applyTransportParameters() + // On the server side, the early connection is ready as soon as we processed + // the client's transport parameters. + close(s.earlyConnReadyChan) + } +} + +func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { + if s.logger.Debug() { + s.logger.Debugf("Processed Transport Parameters: %s", params) + } + if s.tracer != nil { + s.tracer.ReceivedTransportParameters(params) + } + + // check the initial_source_connection_id + if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { + return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID) + } + + if s.perspective == protocol.PerspectiveServer { + return nil + } + // check the original_destination_connection_id + if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { + return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID) + } + if s.retrySrcConnID != nil { // a Retry was performed + if params.RetrySourceConnectionID == nil { + return errors.New("missing retry_source_connection_id") + } + if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { + return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID) + } + } else if params.RetrySourceConnectionID != nil { + return errors.New("received retry_source_connection_id, although no Retry was performed") + } + return nil +} + +func (s *connection) applyTransportParameters() { + params := s.peerParams + // Our local idle timeout will always be > 0. + s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) + s.keepAliveInterval = utils.MinDuration(s.config.KeepAlivePeriod, utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) + s.streamsMap.UpdateLimits(params) + s.packer.HandleTransportParameters(params) + s.frameParser.SetAckDelayExponent(params.AckDelayExponent) + s.connFlowController.UpdateSendWindow(params.InitialMaxData) + s.rttStats.SetMaxAckDelay(params.MaxAckDelay) + s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) + if params.StatelessResetToken != nil { + s.connIDManager.SetStatelessResetToken(*params.StatelessResetToken) + } + // We don't support connection migration yet, so we don't have any use for the preferred_address. + if params.PreferredAddress != nil { + // Retire the connection ID. + s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken) + } +} + +func (s *connection) sendPackets() error { + s.pacingDeadline = time.Time{} + + var sentPacket bool // only used in for packets sent in send mode SendAny + for { + sendMode := s.sentPacketHandler.SendMode() + if sendMode == ackhandler.SendAny && s.handshakeComplete && !s.sentPacketHandler.HasPacingBudget() { + deadline := s.sentPacketHandler.TimeUntilSend() + if deadline.IsZero() { + deadline = deadlineSendImmediately + } + s.pacingDeadline = deadline + // Allow sending of an ACK if we're pacing limit (if we haven't sent out a packet yet). + // This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) + // sends enough ACKs to allow its peer to utilize the bandwidth. + if sentPacket { + return nil + } + sendMode = ackhandler.SendAck + } + switch sendMode { + case ackhandler.SendNone: + return nil + case ackhandler.SendAck: + // If we already sent packets, and the send mode switches to SendAck, + // as we've just become congestion limited. + // There's no need to try to send an ACK at this moment. + if sentPacket { + return nil + } + // We can at most send a single ACK only packet. + // There will only be a new ACK after receiving new packets. + // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. + return s.maybeSendAckOnlyPacket() + case ackhandler.SendPTOInitial: + if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { + return err + } + case ackhandler.SendPTOHandshake: + if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { + return err + } + case ackhandler.SendPTOAppData: + if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { + return err + } + case ackhandler.SendAny: + sent, err := s.sendPacket() + if err != nil || !sent { + return err + } + sentPacket = true + default: + return fmt.Errorf("BUG: invalid send mode %d", sendMode) + } + // Prioritize receiving of packets over sending out more packets. + if len(s.receivedPackets) > 0 { + s.pacingDeadline = deadlineSendImmediately + return nil + } + if s.sendQueue.WouldBlock() { + return nil + } + } +} + +func (s *connection) maybeSendAckOnlyPacket() error { + packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) + if err != nil { + return err + } + if packet == nil { + return nil + } + s.sendPackedPacket(packet, time.Now()) + return nil +} + +func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { + // Queue probe packets until we actually send out a packet, + // or until there are no more packets to queue. + var packet *packedPacket + for { + if wasQueued := s.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued { + break + } + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) + if err != nil { + return err + } + if packet != nil { + break + } + } + if packet == nil { + //nolint:exhaustive // Cannot send probe packets for 0-RTT. + switch encLevel { + case protocol.EncryptionInitial: + s.retransmissionQueue.AddInitial(&wire.PingFrame{}) + case protocol.EncryptionHandshake: + s.retransmissionQueue.AddHandshake(&wire.PingFrame{}) + case protocol.Encryption1RTT: + s.retransmissionQueue.AddAppData(&wire.PingFrame{}) + default: + panic("unexpected encryption level") + } + var err error + packet, err = s.packer.MaybePackProbePacket(encLevel) + if err != nil { + return err + } + } + if packet == nil || packet.packetContents == nil { + return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) + } + s.sendPackedPacket(packet, time.Now()) + return nil +} + +func (s *connection) sendPacket() (bool, error) { + if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { + s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) + } + s.windowUpdateQueue.QueueAll() + + now := time.Now() + if !s.handshakeConfirmed { + packet, err := s.packer.PackCoalescedPacket() + if err != nil || packet == nil { + return false, err + } + s.sentFirstPacket = true + s.logCoalescedPacket(packet) + for _, p := range packet.packets { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) + } + s.connIDManager.SentPacket() + s.sendQueue.Send(packet.buffer) + return true, nil + } + if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { + packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing()) + if err != nil { + return false, err + } + s.sendPackedPacket(packet, now) + return true, nil + } + packet, err := s.packer.PackPacket() + if err != nil || packet == nil { + return false, err + } + s.sendPackedPacket(packet, now) + return true, nil +} + +func (s *connection) sendPackedPacket(packet *packedPacket, now time.Time) { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.logPacket(packet) + s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(now, s.retransmissionQueue)) + s.connIDManager.SentPacket() + s.sendQueue.Send(packet.buffer) +} + +func (s *connection) sendConnectionClose(e error) ([]byte, error) { + var packet *coalescedPacket + var err error + var transportErr *qerr.TransportError + var applicationErr *qerr.ApplicationError + if errors.As(e, &transportErr) { + packet, err = s.packer.PackConnectionClose(transportErr) + } else if errors.As(e, &applicationErr) { + packet, err = s.packer.PackApplicationClose(applicationErr) + } else { + packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ + ErrorCode: qerr.InternalError, + ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), + }) + } + if err != nil { + return nil, err + } + s.logCoalescedPacket(packet) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data) +} + +func (s *connection) logPacketContents(p *packetContents) { + // tracing + if s.tracer != nil { + frames := make([]logging.Frame, 0, len(p.frames)) + for _, f := range p.frames { + frames = append(frames, logutils.ConvertFrame(f.Frame)) + } + s.tracer.SentPacket(p.header, p.length, p.ack, frames) + } + + // quic-go logging + if !s.logger.Debug() { + return + } + p.header.Log(s.logger) + if p.ack != nil { + wire.LogFrame(s.logger, p.ack, true) + } + for _, frame := range p.frames { + wire.LogFrame(s.logger, frame.Frame, true) + } +} + +func (s *connection) logCoalescedPacket(packet *coalescedPacket) { + if s.logger.Debug() { + if len(packet.packets) > 1 { + s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID) + } else { + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.packets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.packets[0].EncryptionLevel()) + } + } + for _, p := range packet.packets { + s.logPacketContents(p) + } +} + +func (s *connection) logPacket(packet *packedPacket) { + if s.logger.Debug() { + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) + } + s.logPacketContents(packet.packetContents) +} + +// AcceptStream returns the next stream openend by the peer +func (s *connection) AcceptStream(ctx context.Context) (Stream, error) { + return s.streamsMap.AcceptStream(ctx) +} + +func (s *connection) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + return s.streamsMap.AcceptUniStream(ctx) +} + +// OpenStream opens a stream +func (s *connection) OpenStream() (Stream, error) { + return s.streamsMap.OpenStream() +} + +func (s *connection) OpenStreamSync(ctx context.Context) (Stream, error) { + return s.streamsMap.OpenStreamSync(ctx) +} + +func (s *connection) OpenUniStream() (SendStream, error) { + return s.streamsMap.OpenUniStream() +} + +func (s *connection) OpenUniStreamSync(ctx context.Context) (SendStream, error) { + return s.streamsMap.OpenUniStreamSync(ctx) +} + +func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { + initialSendWindow := s.peerParams.InitialMaxStreamDataUni + if id.Type() == protocol.StreamTypeBidi { + if id.InitiatedBy() == s.perspective { + initialSendWindow = s.peerParams.InitialMaxStreamDataBidiRemote + } else { + initialSendWindow = s.peerParams.InitialMaxStreamDataBidiLocal + } + } + return flowcontrol.NewStreamFlowController( + id, + s.connFlowController, + protocol.ByteCount(s.config.InitialStreamReceiveWindow), + protocol.ByteCount(s.config.MaxStreamReceiveWindow), + initialSendWindow, + s.onHasStreamWindowUpdate, + s.rttStats, + s.logger, + ) +} + +// scheduleSending signals that we have data for sending +func (s *connection) scheduleSending() { + select { + case s.sendingScheduled <- struct{}{}: + default: + } +} + +func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, hdr *wire.Header) { + if s.handshakeComplete { + panic("shouldn't queue undecryptable packets after handshake completion") + } + if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDOSPrevention) + } + s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) + return + } + s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) + if s.tracer != nil { + s.tracer.BufferedPacket(logging.PacketTypeFromHeader(hdr)) + } + s.undecryptablePackets = append(s.undecryptablePackets, p) +} + +func (s *connection) queueControlFrame(f wire.Frame) { + s.framer.QueueControlFrame(f) + s.scheduleSending() +} + +func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID) { + s.windowUpdateQueue.AddStream(id) + s.scheduleSending() +} + +func (s *connection) onHasConnectionWindowUpdate() { + s.windowUpdateQueue.AddConnection() + s.scheduleSending() +} + +func (s *connection) onHasStreamData(id protocol.StreamID) { + s.framer.AddActiveStream(id) + s.scheduleSending() +} + +func (s *connection) onStreamCompleted(id protocol.StreamID) { + if err := s.streamsMap.DeleteStream(id); err != nil { + s.closeLocal(err) + } +} + +func (s *connection) SendMessage(p []byte) error { + f := &wire.DatagramFrame{DataLenPresent: true} + if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { + return errors.New("message too large") + } + f.Data = make([]byte, len(p)) + copy(f.Data, p) + return s.datagramQueue.AddAndWait(f) +} + +func (s *connection) ReceiveMessage() ([]byte, error) { + return s.datagramQueue.Receive() +} + +func (s *connection) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *connection) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *connection) getPerspective() protocol.Perspective { + return s.perspective +} + +func (s *connection) GetVersion() protocol.VersionNumber { + return s.version +} + +func (s *connection) NextConnection() Connection { + <-s.HandshakeComplete().Done() + s.streamsMap.UseResetMaps() + return s +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go new file mode 100644 index 00000000..36e21d33 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go @@ -0,0 +1,115 @@ +package quic + +import ( + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type cryptoStream interface { + // for receiving data + HandleCryptoFrame(*wire.CryptoFrame) error + GetCryptoData() []byte + Finish() error + // for sending data + io.Writer + HasData() bool + PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame +} + +type cryptoStreamImpl struct { + queue *frameSorter + msgBuf []byte + + highestOffset protocol.ByteCount + finished bool + + writeOffset protocol.ByteCount + writeBuf []byte +} + +func newCryptoStream() cryptoStream { + return &cryptoStreamImpl{queue: newFrameSorter()} +} + +func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { + highestOffset := f.Offset + protocol.ByteCount(len(f.Data)) + if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset { + return &qerr.TransportError{ + ErrorCode: qerr.CryptoBufferExceeded, + ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset), + } + } + if s.finished { + if highestOffset > s.highestOffset { + // reject crypto data received after this stream was already finished + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received crypto data after change of encryption level", + } + } + // ignore data with a smaller offset than the highest received + // could e.g. be a retransmission + return nil + } + s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) + if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { + return err + } + for { + _, data, _ := s.queue.Pop() + if data == nil { + return nil + } + s.msgBuf = append(s.msgBuf, data...) + } +} + +// GetCryptoData retrieves data that was received in CRYPTO frames +func (s *cryptoStreamImpl) GetCryptoData() []byte { + if len(s.msgBuf) < 4 { + return nil + } + msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) + if len(s.msgBuf) < msgLen { + return nil + } + msg := make([]byte, msgLen) + copy(msg, s.msgBuf[:msgLen]) + s.msgBuf = s.msgBuf[msgLen:] + return msg +} + +func (s *cryptoStreamImpl) Finish() error { + if s.queue.HasMoreData() { + return &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "encryption level changed, but crypto stream has more data to read", + } + } + s.finished = true + return nil +} + +// Writes writes data that should be sent out in CRYPTO frames +func (s *cryptoStreamImpl) Write(p []byte) (int, error) { + s.writeBuf = append(s.writeBuf, p...) + return len(p), nil +} + +func (s *cryptoStreamImpl) HasData() bool { + return len(s.writeBuf) > 0 +} + +func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { + f := &wire.CryptoFrame{Offset: s.writeOffset} + n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) + f.Data = s.writeBuf[:n] + s.writeBuf = s.writeBuf[n:] + s.writeOffset += n + return f +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream_manager.go b/vendor/github.com/lucas-clemente/quic-go/crypto_stream_manager.go new file mode 100644 index 00000000..66f90049 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/crypto_stream_manager.go @@ -0,0 +1,61 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type cryptoDataHandler interface { + HandleMessage([]byte, protocol.EncryptionLevel) bool +} + +type cryptoStreamManager struct { + cryptoHandler cryptoDataHandler + + initialStream cryptoStream + handshakeStream cryptoStream + oneRTTStream cryptoStream +} + +func newCryptoStreamManager( + cryptoHandler cryptoDataHandler, + initialStream cryptoStream, + handshakeStream cryptoStream, + oneRTTStream cryptoStream, +) *cryptoStreamManager { + return &cryptoStreamManager{ + cryptoHandler: cryptoHandler, + initialStream: initialStream, + handshakeStream: handshakeStream, + oneRTTStream: oneRTTStream, + } +} + +func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) { + var str cryptoStream + //nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets. + switch encLevel { + case protocol.EncryptionInitial: + str = m.initialStream + case protocol.EncryptionHandshake: + str = m.handshakeStream + case protocol.Encryption1RTT: + str = m.oneRTTStream + default: + return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) + } + if err := str.HandleCryptoFrame(frame); err != nil { + return false, err + } + for { + data := str.GetCryptoData() + if data == nil { + return false, nil + } + if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { + return true, str.Finish() + } + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/datagram_queue.go b/vendor/github.com/lucas-clemente/quic-go/datagram_queue.go new file mode 100644 index 00000000..b1cbbf6d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/datagram_queue.go @@ -0,0 +1,87 @@ +package quic + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type datagramQueue struct { + sendQueue chan *wire.DatagramFrame + rcvQueue chan []byte + + closeErr error + closed chan struct{} + + hasData func() + + dequeued chan struct{} + + logger utils.Logger +} + +func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { + return &datagramQueue{ + hasData: hasData, + sendQueue: make(chan *wire.DatagramFrame, 1), + rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), + dequeued: make(chan struct{}), + closed: make(chan struct{}), + logger: logger, + } +} + +// AddAndWait queues a new DATAGRAM frame for sending. +// It blocks until the frame has been dequeued. +func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { + select { + case h.sendQueue <- f: + h.hasData() + case <-h.closed: + return h.closeErr + } + + select { + case <-h.dequeued: + return nil + case <-h.closed: + return h.closeErr + } +} + +// Get dequeues a DATAGRAM frame for sending. +func (h *datagramQueue) Get() *wire.DatagramFrame { + select { + case f := <-h.sendQueue: + h.dequeued <- struct{}{} + return f + default: + return nil + } +} + +// HandleDatagramFrame handles a received DATAGRAM frame. +func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { + data := make([]byte, len(f.Data)) + copy(data, f.Data) + select { + case h.rcvQueue <- data: + default: + h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + } +} + +// Receive gets a received DATAGRAM frame. +func (h *datagramQueue) Receive() ([]byte, error) { + select { + case data := <-h.rcvQueue: + return data, nil + case <-h.closed: + return nil, h.closeErr + } +} + +func (h *datagramQueue) CloseWithError(e error) { + h.closeErr = e + close(h.closed) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/errors.go b/vendor/github.com/lucas-clemente/quic-go/errors.go new file mode 100644 index 00000000..0c9f0004 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/errors.go @@ -0,0 +1,58 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/qerr" +) + +type ( + TransportError = qerr.TransportError + ApplicationError = qerr.ApplicationError + VersionNegotiationError = qerr.VersionNegotiationError + StatelessResetError = qerr.StatelessResetError + IdleTimeoutError = qerr.IdleTimeoutError + HandshakeTimeoutError = qerr.HandshakeTimeoutError +) + +type ( + TransportErrorCode = qerr.TransportErrorCode + ApplicationErrorCode = qerr.ApplicationErrorCode + StreamErrorCode = qerr.StreamErrorCode +) + +const ( + NoError = qerr.NoError + InternalError = qerr.InternalError + ConnectionRefused = qerr.ConnectionRefused + FlowControlError = qerr.FlowControlError + StreamLimitError = qerr.StreamLimitError + StreamStateError = qerr.StreamStateError + FinalSizeError = qerr.FinalSizeError + FrameEncodingError = qerr.FrameEncodingError + TransportParameterError = qerr.TransportParameterError + ConnectionIDLimitError = qerr.ConnectionIDLimitError + ProtocolViolation = qerr.ProtocolViolation + InvalidToken = qerr.InvalidToken + ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode + CryptoBufferExceeded = qerr.CryptoBufferExceeded + KeyUpdateError = qerr.KeyUpdateError + AEADLimitReached = qerr.AEADLimitReached + NoViablePathError = qerr.NoViablePathError +) + +// A StreamError is used for Stream.CancelRead and Stream.CancelWrite. +// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing. +type StreamError struct { + StreamID StreamID + ErrorCode StreamErrorCode +} + +func (e *StreamError) Is(target error) bool { + _, ok := target.(*StreamError) + return ok +} + +func (e *StreamError) Error() string { + return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/frame_sorter.go b/vendor/github.com/lucas-clemente/quic-go/frame_sorter.go new file mode 100644 index 00000000..aeafa7d4 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/frame_sorter.go @@ -0,0 +1,224 @@ +package quic + +import ( + "errors" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type frameSorterEntry struct { + Data []byte + DoneCb func() +} + +type frameSorter struct { + queue map[protocol.ByteCount]frameSorterEntry + readPos protocol.ByteCount + gaps *utils.ByteIntervalList +} + +var errDuplicateStreamData = errors.New("duplicate stream data") + +func newFrameSorter() *frameSorter { + s := frameSorter{ + gaps: utils.NewByteIntervalList(), + queue: make(map[protocol.ByteCount]frameSorterEntry), + } + s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) + return &s +} + +func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error { + err := s.push(data, offset, doneCb) + if err == errDuplicateStreamData { + if doneCb != nil { + doneCb() + } + return nil + } + return err +} + +func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error { + if len(data) == 0 { + return errDuplicateStreamData + } + + start := offset + end := offset + protocol.ByteCount(len(data)) + + if end <= s.gaps.Front().Value.Start { + return errDuplicateStreamData + } + + startGap, startsInGap := s.findStartGap(start) + endGap, endsInGap := s.findEndGap(startGap, end) + + startGapEqualsEndGap := startGap == endGap + + if (startGapEqualsEndGap && end <= startGap.Value.Start) || + (!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) { + return errDuplicateStreamData + } + + startGapNext := startGap.Next() + startGapEnd := startGap.Value.End // save it, in case startGap is modified + endGapStart := endGap.Value.Start // save it, in case endGap is modified + endGapEnd := endGap.Value.End // save it, in case endGap is modified + var adjustedStartGapEnd bool + var wasCut bool + + pos := start + var hasReplacedAtLeastOne bool + for { + oldEntry, ok := s.queue[pos] + if !ok { + break + } + oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) + if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) { + // The existing frame is shorter than the new frame. Replace it. + delete(s.queue, pos) + pos += oldEntryLen + hasReplacedAtLeastOne = true + if oldEntry.DoneCb != nil { + oldEntry.DoneCb() + } + } else { + if !hasReplacedAtLeastOne { + return errDuplicateStreamData + } + // The existing frame is longer than the new frame. + // Cut the new frame such that the end aligns with the start of the existing frame. + data = data[:pos-start] + end = pos + wasCut = true + break + } + } + + if !startsInGap && !hasReplacedAtLeastOne { + // cut the frame, such that it starts at the start of the gap + data = data[startGap.Value.Start-start:] + start = startGap.Value.Start + wasCut = true + } + if start <= startGap.Value.Start { + if end >= startGap.Value.End { + // The frame covers the whole startGap. Delete the gap. + s.gaps.Remove(startGap) + } else { + startGap.Value.Start = end + } + } else if !hasReplacedAtLeastOne { + startGap.Value.End = start + adjustedStartGapEnd = true + } + + if !startGapEqualsEndGap { + s.deleteConsecutive(startGapEnd) + var nextGap *utils.ByteIntervalElement + for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap { + nextGap = gap.Next() + s.deleteConsecutive(gap.Value.End) + s.gaps.Remove(gap) + } + } + + if !endsInGap && start != endGapEnd && end > endGapEnd { + // cut the frame, such that it ends at the end of the gap + data = data[:endGapEnd-start] + end = endGapEnd + wasCut = true + } + if end == endGapEnd { + if !startGapEqualsEndGap { + // The frame covers the whole endGap. Delete the gap. + s.gaps.Remove(endGap) + } + } else { + if startGapEqualsEndGap && adjustedStartGapEnd { + // The frame split the existing gap into two. + s.gaps.InsertAfter(utils.ByteInterval{Start: end, End: startGapEnd}, startGap) + } else if !startGapEqualsEndGap { + endGap.Value.Start = end + } + } + + if wasCut && len(data) < protocol.MinStreamFrameBufferSize { + newData := make([]byte, len(data)) + copy(newData, data) + data = newData + if doneCb != nil { + doneCb() + doneCb = nil + } + } + + if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps { + return errors.New("too many gaps in received data") + } + + s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb} + return nil +} + +func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { + for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { + if offset >= gap.Value.Start && offset <= gap.Value.End { + return gap, true + } + if offset < gap.Value.Start { + return gap, false + } + } + panic("no gap found") +} + +func (s *frameSorter) findEndGap(startGap *utils.ByteIntervalElement, offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { + for gap := startGap; gap != nil; gap = gap.Next() { + if offset >= gap.Value.Start && offset < gap.Value.End { + return gap, true + } + if offset < gap.Value.Start { + return gap.Prev(), false + } + } + panic("no gap found") +} + +// deleteConsecutive deletes consecutive frames from the queue, starting at pos +func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) { + for { + oldEntry, ok := s.queue[pos] + if !ok { + break + } + oldEntryLen := protocol.ByteCount(len(oldEntry.Data)) + delete(s.queue, pos) + if oldEntry.DoneCb != nil { + oldEntry.DoneCb() + } + pos += oldEntryLen + } +} + +func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) { + entry, ok := s.queue[s.readPos] + if !ok { + return s.readPos, nil, nil + } + delete(s.queue, s.readPos) + offset := s.readPos + s.readPos += protocol.ByteCount(len(entry.Data)) + if s.gaps.Front().Value.End <= s.readPos { + panic("frame sorter BUG: read position higher than a gap") + } + return offset, entry.Data, entry.DoneCb +} + +// HasMoreData says if there is any more data queued at *any* offset. +func (s *frameSorter) HasMoreData() bool { + return len(s.queue) > 0 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/framer.go b/vendor/github.com/lucas-clemente/quic-go/framer.go new file mode 100644 index 00000000..29d36b85 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/framer.go @@ -0,0 +1,171 @@ +package quic + +import ( + "errors" + "sync" + + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +type framer interface { + HasData() bool + + QueueControlFrame(wire.Frame) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + + AddActiveStream(protocol.StreamID) + AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + + Handle0RTTRejection() error +} + +type framerI struct { + mutex sync.Mutex + + streamGetter streamGetter + version protocol.VersionNumber + + activeStreams map[protocol.StreamID]struct{} + streamQueue []protocol.StreamID + + controlFrameMutex sync.Mutex + controlFrames []wire.Frame +} + +var _ framer = &framerI{} + +func newFramer( + streamGetter streamGetter, + v protocol.VersionNumber, +) framer { + return &framerI{ + streamGetter: streamGetter, + activeStreams: make(map[protocol.StreamID]struct{}), + version: v, + } +} + +func (f *framerI) HasData() bool { + f.mutex.Lock() + hasData := len(f.streamQueue) > 0 + f.mutex.Unlock() + if hasData { + return true + } + f.controlFrameMutex.Lock() + hasData = len(f.controlFrames) > 0 + f.controlFrameMutex.Unlock() + return hasData +} + +func (f *framerI) QueueControlFrame(frame wire.Frame) { + f.controlFrameMutex.Lock() + f.controlFrames = append(f.controlFrames, frame) + f.controlFrameMutex.Unlock() +} + +func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + var length protocol.ByteCount + f.controlFrameMutex.Lock() + for len(f.controlFrames) > 0 { + frame := f.controlFrames[len(f.controlFrames)-1] + frameLen := frame.Length(f.version) + if length+frameLen > maxLen { + break + } + frames = append(frames, ackhandler.Frame{Frame: frame}) + length += frameLen + f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] + } + f.controlFrameMutex.Unlock() + return frames, length +} + +func (f *framerI) AddActiveStream(id protocol.StreamID) { + f.mutex.Lock() + if _, ok := f.activeStreams[id]; !ok { + f.streamQueue = append(f.streamQueue, id) + f.activeStreams[id] = struct{}{} + } + f.mutex.Unlock() +} + +func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + var length protocol.ByteCount + var lastFrame *ackhandler.Frame + f.mutex.Lock() + // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet + numActiveStreams := len(f.streamQueue) + for i := 0; i < numActiveStreams; i++ { + if protocol.MinStreamFrameSize+length > maxLen { + break + } + id := f.streamQueue[0] + f.streamQueue = f.streamQueue[1:] + // This should never return an error. Better check it anyway. + // The stream will only be in the streamQueue, if it enqueued itself there. + str, err := f.streamGetter.GetOrOpenSendStream(id) + // The stream can be nil if it completed after it said it had data. + if str == nil || err != nil { + delete(f.activeStreams, id) + continue + } + remainingLen := maxLen - length + // For the last STREAM frame, we'll remove the DataLen field later. + // Therefore, we can pretend to have more bytes available when popping + // the STREAM frame (which will always have the DataLen set). + remainingLen += quicvarint.Len(uint64(remainingLen)) + frame, hasMoreData := str.popStreamFrame(remainingLen) + if hasMoreData { // put the stream back in the queue (at the end) + f.streamQueue = append(f.streamQueue, id) + } else { // no more data to send. Stream is not active any more + delete(f.activeStreams, id) + } + // The frame can be nil + // * if the receiveStream was canceled after it said it had data + // * the remaining size doesn't allow us to add another STREAM frame + if frame == nil { + continue + } + frames = append(frames, *frame) + length += frame.Length(f.version) + lastFrame = frame + } + f.mutex.Unlock() + if lastFrame != nil { + lastFrameLen := lastFrame.Length(f.version) + // account for the smaller size of the last STREAM frame + lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false + length += lastFrame.Length(f.version) - lastFrameLen + } + return frames, length +} + +func (f *framerI) Handle0RTTRejection() error { + f.mutex.Lock() + defer f.mutex.Unlock() + + f.controlFrameMutex.Lock() + f.streamQueue = f.streamQueue[:0] + for id := range f.activeStreams { + delete(f.activeStreams, id) + } + var j int + for i, frame := range f.controlFrames { + switch frame.(type) { + case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame: + return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT") + case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame: + continue + default: + f.controlFrames[j] = f.controlFrames[i] + j++ + } + } + f.controlFrames = f.controlFrames[:j] + f.controlFrameMutex.Unlock() + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/body.go b/vendor/github.com/lucas-clemente/quic-go/http3/body.go new file mode 100644 index 00000000..b3d1afd7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/body.go @@ -0,0 +1,130 @@ +package http3 + +import ( + "context" + "io" + "net" + + "github.com/lucas-clemente/quic-go" +) + +// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by: +// * for the server: the http.Request.Body +// * for the client: the http.Response.Body +// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set. +// When a stream is taken over, it's the caller's responsibility to close the stream. +type HTTPStreamer interface { + HTTPStream() Stream +} + +type StreamCreator interface { + OpenStream() (quic.Stream, error) + OpenStreamSync(context.Context) (quic.Stream, error) + OpenUniStream() (quic.SendStream, error) + OpenUniStreamSync(context.Context) (quic.SendStream, error) + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +var _ StreamCreator = quic.Connection(nil) + +// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body. +// It is used by WebTransport to create WebTransport streams after a session has been established. +type Hijacker interface { + StreamCreator() StreamCreator +} + +// The body of a http.Request or http.Response. +type body struct { + str quic.Stream + + wasHijacked bool // set when HTTPStream is called +} + +var ( + _ io.ReadCloser = &body{} + _ HTTPStreamer = &body{} +) + +func newRequestBody(str Stream) *body { + return &body{str: str} +} + +func (r *body) HTTPStream() Stream { + r.wasHijacked = true + return r.str +} + +func (r *body) wasStreamHijacked() bool { + return r.wasHijacked +} + +func (r *body) Read(b []byte) (int, error) { + return r.str.Read(b) +} + +func (r *body) Close() error { + r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + return nil +} + +type hijackableBody struct { + body + conn quic.Connection // only needed to implement Hijacker + + // only set for the http.Response + // The channel is closed when the user is done with this response: + // either when Read() errors, or when Close() is called. + reqDone chan<- struct{} + reqDoneClosed bool +} + +var ( + _ Hijacker = &hijackableBody{} + _ HTTPStreamer = &hijackableBody{} +) + +func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody { + return &hijackableBody{ + body: body{ + str: str, + }, + reqDone: done, + conn: conn, + } +} + +func (r *hijackableBody) StreamCreator() StreamCreator { + return r.conn +} + +func (r *hijackableBody) Read(b []byte) (int, error) { + n, err := r.str.Read(b) + if err != nil { + r.requestDone() + } + return n, err +} + +func (r *hijackableBody) requestDone() { + if r.reqDoneClosed || r.reqDone == nil { + return + } + close(r.reqDone) + r.reqDoneClosed = true +} + +func (r *body) StreamID() quic.StreamID { + return r.str.StreamID() +} + +func (r *hijackableBody) Close() error { + r.requestDone() + // If the EOF was read, CancelRead() is a no-op. + r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + return nil +} + +func (r *hijackableBody) HTTPStream() Stream { + return r.str +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/client.go b/vendor/github.com/lucas-clemente/quic-go/http3/client.go new file mode 100644 index 00000000..e4c51688 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/client.go @@ -0,0 +1,423 @@ +package http3 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "sync" + "time" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/marten-seemann/qpack" +) + +// MethodGet0RTT allows a GET request to be sent using 0-RTT. +// Note that 0-RTT data doesn't provide replay protection. +const MethodGet0RTT = "GET_0RTT" + +const ( + defaultUserAgent = "quic-go HTTP/3" + defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB +) + +var defaultQuicConfig = &quic.Config{ + MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams + KeepAlivePeriod: 10 * time.Second, + Versions: []protocol.VersionNumber{protocol.VersionTLS}, +} + +type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + +var dialAddr = quic.DialAddrEarlyContext + +type roundTripperOpts struct { + DisableCompression bool + EnableDatagram bool + MaxHeaderBytes int64 + AdditionalSettings map[uint64]uint64 + StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) +} + +// client is a HTTP3 client doing requests +type client struct { + tlsConf *tls.Config + config *quic.Config + opts *roundTripperOpts + + dialOnce sync.Once + dialer dialFunc + handshakeErr error + + requestWriter *requestWriter + + decoder *qpack.Decoder + + hostname string + conn quic.EarlyConnection + + logger utils.Logger +} + +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { + if conf == nil { + conf = defaultQuicConfig.Clone() + } else if len(conf.Versions) == 0 { + conf = conf.Clone() + conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} + } + if len(conf.Versions) != 1 { + return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") + } + if conf.MaxIncomingStreams == 0 { + conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams + } + conf.EnableDatagrams = opts.EnableDatagram + logger := utils.DefaultLogger.WithPrefix("h3 client") + + if tlsConf == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() + } + // Replace existing ALPNs by H3 + tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} + + return &client{ + hostname: authorityAddr("https", hostname), + tlsConf: tlsConf, + requestWriter: newRequestWriter(logger), + decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), + config: conf, + opts: opts, + dialer: dialer, + logger: logger, + }, nil +} + +func (c *client) dial(ctx context.Context) error { + var err error + if c.dialer != nil { + c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) + } else { + c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) + } + if err != nil { + return err + } + + // send the SETTINGs frame, using 0-RTT data, if possible + go func() { + if err := c.setupConn(); err != nil { + c.logger.Debugf("Setting up connection failed: %s", err) + c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") + } + }() + + if c.opts.StreamHijacker != nil { + go c.handleBidirectionalStreams() + } + go c.handleUnidirectionalStreams() + return nil +} + +func (c *client) setupConn() error { + // open the control stream + str, err := c.conn.OpenUniStream() + if err != nil { + return err + } + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) + // send the SETTINGS frame + (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Write(buf) + _, err = str.Write(buf.Bytes()) + return err +} + +func (c *client) handleBidirectionalStreams() { + for { + str, err := c.conn.AcceptStream(context.Background()) + if err != nil { + c.logger.Debugf("accepting bidirectional stream failed: %s", err) + return + } + go func(str quic.Stream) { + _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { + return c.opts.StreamHijacker(ft, c.conn, str, e) + }) + if err == errHijacked { + return + } + if err != nil { + c.logger.Debugf("error handling stream: %s", err) + } + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + }(str) + } +} + +func (c *client) handleUnidirectionalStreams() { + for { + str, err := c.conn.AcceptUniStream(context.Background()) + if err != nil { + c.logger.Debugf("accepting unidirectional stream failed: %s", err) + return + } + + go func(str quic.ReceiveStream) { + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) { + return + } + c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) + return + } + // We're only interested in the control stream here. + switch streamType { + case streamTypeControlStream: + case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: + // Our QPACK implementation doesn't use the dynamic table yet. + // TODO: check that only one stream of each type is opened. + return + case streamTypePushStream: + // We never increased the Push ID, so we don't expect any push streams. + c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") + return + default: + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) { + return + } + str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + return + } + f, err := parseNextFrame(str, nil) + if err != nil { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + return + } + sf, ok := f.(*settingsFrame) + if !ok { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + return + } + if !sf.Datagram { + return + } + // If datagram support was enabled on our side as well as on the server side, + // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. + // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). + if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + } + }(str) + } +} + +func (c *client) Close() error { + if c.conn == nil { + return nil + } + return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "") +} + +func (c *client) maxHeaderBytes() uint64 { + if c.opts.MaxHeaderBytes <= 0 { + return defaultMaxResponseHeaderBytes + } + return uint64(c.opts.MaxHeaderBytes) +} + +// RoundTripOpt executes a request and returns a response +func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { + return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) + } + + c.dialOnce.Do(func() { + c.handshakeErr = c.dial(req.Context()) + }) + + if c.handshakeErr != nil { + return nil, c.handshakeErr + } + + // Immediately send out this request, if this is a 0-RTT request. + if req.Method == MethodGet0RTT { + req.Method = http.MethodGet + } else { + // wait for the handshake to complete + select { + case <-c.conn.HandshakeComplete().Done(): + case <-req.Context().Done(): + return nil, req.Context().Err() + } + } + + str, err := c.conn.OpenStreamSync(req.Context()) + if err != nil { + return nil, err + } + + // Request Cancellation: + // This go routine keeps running even after RoundTripOpt() returns. + // It is shut down when the application is done processing the body. + reqDone := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + str.CancelRead(quic.StreamErrorCode(errorRequestCanceled)) + case <-reqDone: + } + }() + + rsp, rerr := c.doRequest(req, str, opt, reqDone) + if rerr.err != nil { // if any error occurred + close(reqDone) + if rerr.streamErr != 0 { // if it was a stream error + str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) + } + if rerr.connErr != 0 { // if it was a connection error + var reason string + if rerr.err != nil { + reason = rerr.err.Error() + } + c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + } + } + return rsp, rerr.err +} + +func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { + defer body.Close() + b := make([]byte, bodyCopyBufferSize) + for { + n, rerr := body.Read(b) + if n == 0 { + if rerr == nil { + continue + } + if rerr == io.EOF { + break + } + } + if _, err := str.Write(b[:n]); err != nil { + return err + } + if rerr != nil { + if rerr == io.EOF { + break + } + str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) + return rerr + } + } + return nil +} + +func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan struct{}) (*http.Response, requestError) { + var requestGzip bool + if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { + requestGzip = true + } + if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil { + return nil, newStreamError(errorInternalError, err) + } + + if req.Body == nil && !opt.DontCloseRequestStream { + str.Close() + } + + hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + if req.Body != nil { + // send the request body asynchronously + go func() { + if err := c.sendRequestBody(hstr, req.Body); err != nil { + c.logger.Errorf("Error writing request: %s", err) + } + if !opt.DontCloseRequestStream { + hstr.Close() + } + }() + } + + frame, err := parseNextFrame(str, nil) + if err != nil { + return nil, newStreamError(errorFrameError, err) + } + hf, ok := frame.(*headersFrame) + if !ok { + return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) + } + if hf.Length > c.maxHeaderBytes() { + return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) + } + headerBlock := make([]byte, hf.Length) + if _, err := io.ReadFull(str, headerBlock); err != nil { + return nil, newStreamError(errorRequestIncomplete, err) + } + hfs, err := c.decoder.DecodeFull(headerBlock) + if err != nil { + // TODO: use the right error code + return nil, newConnError(errorGeneralProtocolError, err) + } + + connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS) + res := &http.Response{ + Proto: "HTTP/3.0", + ProtoMajor: 3, + Header: http.Header{}, + TLS: &connState, + } + for _, hf := range hfs { + switch hf.Name { + case ":status": + status, err := strconv.Atoi(hf.Value) + if err != nil { + return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header")) + } + res.StatusCode = status + res.Status = hf.Value + " " + http.StatusText(status) + default: + res.Header.Add(hf.Name, hf.Value) + } + } + respBody := newResponseBody(hstr, c.conn, reqDone) + + // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. + _, hasTransferEncoding := res.Header["Transfer-Encoding"] + isInformational := res.StatusCode >= 100 && res.StatusCode < 200 + isNoContent := res.StatusCode == 204 + isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 + if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { + res.ContentLength = -1 + if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { + if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { + res.ContentLength = clen64 + } + } + } + + if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = newGzipReader(respBody) + res.Uncompressed = true + } else { + res.Body = respBody + } + + return res, requestError{} +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/error_codes.go b/vendor/github.com/lucas-clemente/quic-go/http3/error_codes.go new file mode 100644 index 00000000..d87eef4a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/error_codes.go @@ -0,0 +1,73 @@ +package http3 + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go" +) + +type errorCode quic.ApplicationErrorCode + +const ( + errorNoError errorCode = 0x100 + errorGeneralProtocolError errorCode = 0x101 + errorInternalError errorCode = 0x102 + errorStreamCreationError errorCode = 0x103 + errorClosedCriticalStream errorCode = 0x104 + errorFrameUnexpected errorCode = 0x105 + errorFrameError errorCode = 0x106 + errorExcessiveLoad errorCode = 0x107 + errorIDError errorCode = 0x108 + errorSettingsError errorCode = 0x109 + errorMissingSettings errorCode = 0x10a + errorRequestRejected errorCode = 0x10b + errorRequestCanceled errorCode = 0x10c + errorRequestIncomplete errorCode = 0x10d + errorMessageError errorCode = 0x10e + errorConnectError errorCode = 0x10f + errorVersionFallback errorCode = 0x110 + errorDatagramError errorCode = 0x4a1268 +) + +func (e errorCode) String() string { + switch e { + case errorNoError: + return "H3_NO_ERROR" + case errorGeneralProtocolError: + return "H3_GENERAL_PROTOCOL_ERROR" + case errorInternalError: + return "H3_INTERNAL_ERROR" + case errorStreamCreationError: + return "H3_STREAM_CREATION_ERROR" + case errorClosedCriticalStream: + return "H3_CLOSED_CRITICAL_STREAM" + case errorFrameUnexpected: + return "H3_FRAME_UNEXPECTED" + case errorFrameError: + return "H3_FRAME_ERROR" + case errorExcessiveLoad: + return "H3_EXCESSIVE_LOAD" + case errorIDError: + return "H3_ID_ERROR" + case errorSettingsError: + return "H3_SETTINGS_ERROR" + case errorMissingSettings: + return "H3_MISSING_SETTINGS" + case errorRequestRejected: + return "H3_REQUEST_REJECTED" + case errorRequestCanceled: + return "H3_REQUEST_CANCELLED" + case errorRequestIncomplete: + return "H3_INCOMPLETE_REQUEST" + case errorMessageError: + return "H3_MESSAGE_ERROR" + case errorConnectError: + return "H3_CONNECT_ERROR" + case errorVersionFallback: + return "H3_VERSION_FALLBACK" + case errorDatagramError: + return "H3_DATAGRAM_ERROR" + default: + return fmt.Sprintf("unknown error code: %#x", uint16(e)) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/frames.go b/vendor/github.com/lucas-clemente/quic-go/http3/frames.go new file mode 100644 index 00000000..474aab48 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/frames.go @@ -0,0 +1,164 @@ +package http3 + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// FrameType is the frame type of a HTTP/3 frame +type FrameType uint64 + +type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error) + +type frame interface{} + +var errHijacked = errors.New("hijacked") + +func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) { + qr := quicvarint.NewReader(r) + for { + t, err := quicvarint.Read(qr) + if err != nil { + if unknownFrameHandler != nil { + hijacked, err := unknownFrameHandler(0, err) + if err != nil { + return nil, err + } + if hijacked { + return nil, errHijacked + } + } + return nil, err + } + // Call the unknownFrameHandler for frames not defined in the HTTP/3 spec + if t > 0xd && unknownFrameHandler != nil { + hijacked, err := unknownFrameHandler(FrameType(t), nil) + if err != nil { + return nil, err + } + if hijacked { + return nil, errHijacked + } + // If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it. + } + l, err := quicvarint.Read(qr) + if err != nil { + return nil, err + } + + switch t { + case 0x0: + return &dataFrame{Length: l}, nil + case 0x1: + return &headersFrame{Length: l}, nil + case 0x4: + return parseSettingsFrame(r, l) + case 0x3: // CANCEL_PUSH + case 0x5: // PUSH_PROMISE + case 0x7: // GOAWAY + case 0xd: // MAX_PUSH_ID + } + // skip over unknown frames + if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { + return nil, err + } + } +} + +type dataFrame struct { + Length uint64 +} + +func (f *dataFrame) Write(b *bytes.Buffer) { + quicvarint.Write(b, 0x0) + quicvarint.Write(b, f.Length) +} + +type headersFrame struct { + Length uint64 +} + +func (f *headersFrame) Write(b *bytes.Buffer) { + quicvarint.Write(b, 0x1) + quicvarint.Write(b, f.Length) +} + +const settingDatagram = 0xffd277 + +type settingsFrame struct { + Datagram bool + Other map[uint64]uint64 // all settings that we don't explicitly recognize +} + +func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { + if l > 8*(1<<10) { + return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l) + } + buf := make([]byte, l) + if _, err := io.ReadFull(r, buf); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + frame := &settingsFrame{} + b := bytes.NewReader(buf) + var readDatagram bool + for b.Len() > 0 { + id, err := quicvarint.Read(b) + if err != nil { // should not happen. We allocated the whole frame already. + return nil, err + } + val, err := quicvarint.Read(b) + if err != nil { // should not happen. We allocated the whole frame already. + return nil, err + } + + switch id { + case settingDatagram: + if readDatagram { + return nil, fmt.Errorf("duplicate setting: %d", id) + } + readDatagram = true + if val != 0 && val != 1 { + return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val) + } + frame.Datagram = val == 1 + default: + if _, ok := frame.Other[id]; ok { + return nil, fmt.Errorf("duplicate setting: %d", id) + } + if frame.Other == nil { + frame.Other = make(map[uint64]uint64) + } + frame.Other[id] = val + } + } + return frame, nil +} + +func (f *settingsFrame) Write(b *bytes.Buffer) { + quicvarint.Write(b, 0x4) + var l protocol.ByteCount + for id, val := range f.Other { + l += quicvarint.Len(id) + quicvarint.Len(val) + } + if f.Datagram { + l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) + } + quicvarint.Write(b, uint64(l)) + if f.Datagram { + quicvarint.Write(b, settingDatagram) + quicvarint.Write(b, 1) + } + for id, val := range f.Other { + quicvarint.Write(b, id) + quicvarint.Write(b, val) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/gzip_reader.go b/vendor/github.com/lucas-clemente/quic-go/http3/gzip_reader.go new file mode 100644 index 00000000..01983ac7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/gzip_reader.go @@ -0,0 +1,39 @@ +package http3 + +// copied from net/transport.go + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +import ( + "compress/gzip" + "io" +) + +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} + +func newGzipReader(body io.ReadCloser) io.ReadCloser { + return &gzipReader{body: body} +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr + } + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/http_stream.go b/vendor/github.com/lucas-clemente/quic-go/http3/http_stream.go new file mode 100644 index 00000000..4c69068c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/http_stream.go @@ -0,0 +1,71 @@ +package http3 + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go" +) + +// A Stream is a HTTP/3 stream. +// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. +type Stream quic.Stream + +// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly +// from the QUIC stream, it writes to and reads from the HTTP stream. +type stream struct { + quic.Stream + + onFrameError func() + bytesRemainingInFrame uint64 +} + +var _ Stream = &stream{} + +func newStream(str quic.Stream, onFrameError func()) *stream { + return &stream{Stream: str, onFrameError: onFrameError} +} + +func (s *stream) Read(b []byte) (int, error) { + if s.bytesRemainingInFrame == 0 { + parseLoop: + for { + frame, err := parseNextFrame(s.Stream, nil) + if err != nil { + return 0, err + } + switch f := frame.(type) { + case *headersFrame: + // skip HEADERS frames + continue + case *dataFrame: + s.bytesRemainingInFrame = f.Length + break parseLoop + default: + s.onFrameError() + // parseNextFrame skips over unknown frame types + // Therefore, this condition is only entered when we parsed another known frame type. + return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) + } + } + } + + var n int + var err error + if s.bytesRemainingInFrame < uint64(len(b)) { + n, err = s.Stream.Read(b[:s.bytesRemainingInFrame]) + } else { + n, err = s.Stream.Read(b) + } + s.bytesRemainingInFrame -= uint64(n) + return n, err +} + +func (s *stream) Write(b []byte) (int, error) { + buf := &bytes.Buffer{} + (&dataFrame{Length: uint64(len(b))}).Write(buf) + if _, err := s.Stream.Write(buf.Bytes()); err != nil { + return 0, err + } + return s.Stream.Write(b) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/request.go b/vendor/github.com/lucas-clemente/quic-go/http3/request.go new file mode 100644 index 00000000..0b9a7278 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/request.go @@ -0,0 +1,113 @@ +package http3 + +import ( + "crypto/tls" + "errors" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/marten-seemann/qpack" +) + +func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { + var path, authority, method, protocol, scheme, contentLengthStr string + + httpHeaders := http.Header{} + for _, h := range headers { + switch h.Name { + case ":path": + path = h.Value + case ":method": + method = h.Value + case ":authority": + authority = h.Value + case ":protocol": + protocol = h.Value + case ":scheme": + scheme = h.Value + case "content-length": + contentLengthStr = h.Value + default: + if !h.IsPseudo() { + httpHeaders.Add(h.Name, h.Value) + } + } + } + + // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 + if len(httpHeaders["Cookie"]) > 0 { + httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) + } + + isConnect := method == http.MethodConnect + // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 + isExtendedConnected := isConnect && protocol != "" + if isExtendedConnected { + if scheme == "" || path == "" || authority == "" { + return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") + } + } else if isConnect { + if path != "" || authority == "" { // normal CONNECT + return nil, errors.New(":path must be empty and :authority must not be empty") + } + } else if len(path) == 0 || len(authority) == 0 || len(method) == 0 { + return nil, errors.New(":path, :authority and :method must not be empty") + } + + var u *url.URL + var requestURI string + var err error + + if isConnect { + u = &url.URL{} + if isExtendedConnected { + u, err = url.ParseRequestURI(path) + if err != nil { + return nil, err + } + } else { + u.Path = path + } + u.Scheme = scheme + u.Host = authority + requestURI = authority + } else { + protocol = "HTTP/3.0" + u, err = url.ParseRequestURI(path) + if err != nil { + return nil, err + } + requestURI = path + } + + var contentLength int64 + if len(contentLengthStr) > 0 { + contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + if err != nil { + return nil, err + } + } + + return &http.Request{ + Method: method, + URL: u, + Proto: protocol, + ProtoMajor: 3, + ProtoMinor: 0, + Header: httpHeaders, + Body: nil, + ContentLength: contentLength, + Host: authority, + RequestURI: requestURI, + TLS: &tls.ConnectionState{}, + }, nil +} + +func hostnameFromRequest(req *http.Request) string { + if req.URL != nil { + return req.URL.Host + } + return "" +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/request_writer.go b/vendor/github.com/lucas-clemente/quic-go/http3/request_writer.go new file mode 100644 index 00000000..0a9c67ac --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/request_writer.go @@ -0,0 +1,283 @@ +package http3 + +import ( + "bytes" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/marten-seemann/qpack" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" + "golang.org/x/net/idna" +) + +const bodyCopyBufferSize = 8 * 1024 + +type requestWriter struct { + mutex sync.Mutex + encoder *qpack.Encoder + headerBuf *bytes.Buffer + + logger utils.Logger +} + +func newRequestWriter(logger utils.Logger) *requestWriter { + headerBuf := &bytes.Buffer{} + encoder := qpack.NewEncoder(headerBuf) + return &requestWriter{ + encoder: encoder, + headerBuf: headerBuf, + logger: logger, + } +} + +func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error { + // TODO: figure out how to add support for trailers + buf := &bytes.Buffer{} + if err := w.writeHeaders(buf, req, gzip); err != nil { + return err + } + _, err := str.Write(buf.Bytes()) + return err +} + +func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error { + w.mutex.Lock() + defer w.mutex.Unlock() + defer w.encoder.Close() + defer w.headerBuf.Reset() + + if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil { + return err + } + + buf := &bytes.Buffer{} + hf := headersFrame{Length: uint64(w.headerBuf.Len())} + hf.Write(buf) + if _, err := wr.Write(buf.Bytes()); err != nil { + return err + } + _, err := wr.Write(w.headerBuf.Bytes()) + return err +} + +// copied from net/transport.go +// Modified to support Extended CONNECT: +// Contrary to what the godoc for the http.Request says, +// we do respect the Proto field if the method is CONNECT. +func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error { + host := req.Host + if host == "" { + host = req.URL.Host + } + host, err := httpguts.PunycodeHostPort(host) + if err != nil { + return err + } + + // http.NewRequest sets this field to HTTP/1.1 + isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" + + var path string + if req.Method != http.MethodConnect || isExtendedConnect { + path = req.URL.RequestURI() + if !validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !validPseudoPath(path) { + if req.URL.Opaque != "" { + return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return fmt.Errorf("invalid request :path %q", orig) + } + } + } + } + + // Check for any invalid headers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } + } + } + + enumerateHeaders := func(f func(name, value string)) { + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). + f(":authority", host) + f(":method", req.Method) + if req.Method != http.MethodConnect || isExtendedConnect { + f(":path", path) + f(":scheme", req.URL.Scheme) + } + if isExtendedConnect { + f(":protocol", req.Proto) + } + if trailers != "" { + f("trailer", trailers) + } + + var didUA bool + for k, vv := range req.Header { + if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") { + // Host is :authority, already sent. + // Content-Length is automatic, set below. + continue + } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") || + strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") || + strings.EqualFold(k, "keep-alive") { + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. + continue + } else if strings.EqualFold(k, "user-agent") { + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). + didUA = true + if len(vv) < 1 { + continue + } + vv = vv[:1] + if vv[0] == "" { + continue + } + + } + + for _, v := range vv { + f(k, v) + } + } + if shouldSendReqContentLength(req.Method, contentLength) { + f("content-length", strconv.FormatInt(contentLength, 10)) + } + if addGzipHeader { + f("accept-encoding", "gzip") + } + if !didUA { + f("user-agent", defaultUserAgent) + } + } + + // Do a first pass over the headers counting bytes to ensure + // we don't exceed cc.peerMaxHeaderListSize. This is done as a + // separate pass before encoding the headers to prevent + // modifying the hpack state. + hlSize := uint64(0) + enumerateHeaders(func(name, value string) { + hf := hpack.HeaderField{Name: name, Value: value} + hlSize += uint64(hf.Size()) + }) + + // TODO: check maximum header list size + // if hlSize > cc.peerMaxHeaderListSize { + // return errRequestHeaderListSize + // } + + // trace := httptrace.ContextClientTrace(req.Context()) + // traceHeaders := traceHasWroteHeaderField(trace) + + // Header list size is ok. Write the headers. + enumerateHeaders(func(name, value string) { + name = strings.ToLower(name) + w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value}) + // if traceHeaders { + // traceWroteHeaderField(trace, name, value) + // } + }) + + return nil +} + +// authorityAddr returns a given authority (a host/IP, or host:port / ip:port) +// and returns a host:port. The port 443 is added if needed. +func authorityAddr(scheme string, authority string) (addr string) { + host, port, err := net.SplitHostPort(authority) + if err != nil { // authority didn't have a port + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a + } + // IPv6 address literal, without a port: + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port + } + return net.JoinHostPort(host, port) +} + +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/' +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +// +// We used to enforce that the path also didn't start with "//", but +// Google's GFE accepts such paths and Chrome sends them, so ignore +// that part of the spec. See golang.org/issue/19103. +func validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/') || v == "*" +} + +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + +// shouldSendReqContentLength reports whether the http2.Transport should send +// a "content-length" request header. This logic is basically a copy of the net/http +// transferWriter.shouldSendContentLength. +// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). +// -1 means unknown. +func shouldSendReqContentLength(method string, contentLength int64) bool { + if contentLength > 0 { + return true + } + if contentLength < 0 { + return false + } + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. + switch method { + case "POST", "PUT", "PATCH": + return true + default: + return false + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/response_writer.go b/vendor/github.com/lucas-clemente/quic-go/http3/response_writer.go new file mode 100644 index 00000000..70a7cd3f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/response_writer.go @@ -0,0 +1,118 @@ +package http3 + +import ( + "bufio" + "bytes" + "net/http" + "strconv" + "strings" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/marten-seemann/qpack" +) + +type responseWriter struct { + conn quic.Connection + bufferedStr *bufio.Writer + + header http.Header + status int // status code passed to WriteHeader + headerWritten bool + + logger utils.Logger +} + +var ( + _ http.ResponseWriter = &responseWriter{} + _ http.Flusher = &responseWriter{} + _ Hijacker = &responseWriter{} +) + +func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { + return &responseWriter{ + header: http.Header{}, + conn: conn, + bufferedStr: bufio.NewWriter(str), + logger: logger, + } +} + +func (w *responseWriter) Header() http.Header { + return w.header +} + +func (w *responseWriter) WriteHeader(status int) { + if w.headerWritten { + return + } + + if status < 100 || status >= 200 { + w.headerWritten = true + } + w.status = status + + var headers bytes.Buffer + enc := qpack.NewEncoder(&headers) + enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) + + for k, v := range w.header { + for index := range v { + enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) + } + } + + buf := &bytes.Buffer{} + (&headersFrame{Length: uint64(headers.Len())}).Write(buf) + w.logger.Infof("Responding with %d", status) + if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { + w.logger.Errorf("could not write headers frame: %s", err.Error()) + } + if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { + w.logger.Errorf("could not write header frame payload: %s", err.Error()) + } + if !w.headerWritten { + w.Flush() + } +} + +func (w *responseWriter) Write(p []byte) (int, error) { + if !w.headerWritten { + w.WriteHeader(200) + } + if !bodyAllowedForStatus(w.status) { + return 0, http.ErrBodyNotAllowed + } + df := &dataFrame{Length: uint64(len(p))} + buf := &bytes.Buffer{} + df.Write(buf) + if _, err := w.bufferedStr.Write(buf.Bytes()); err != nil { + return 0, err + } + return w.bufferedStr.Write(p) +} + +func (w *responseWriter) Flush() { + if err := w.bufferedStr.Flush(); err != nil { + w.logger.Errorf("could not flush to stream: %s", err.Error()) + } +} + +func (w *responseWriter) StreamCreator() StreamCreator { + return w.conn +} + +// copied from http2/http2.go +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 2616, section 4.4. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/roundtrip.go b/vendor/github.com/lucas-clemente/quic-go/http3/roundtrip.go new file mode 100644 index 00000000..a4d0a312 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/roundtrip.go @@ -0,0 +1,221 @@ +package http3 + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + + "github.com/lucas-clemente/quic-go" + + "golang.org/x/net/http/httpguts" +) + +type roundTripCloser interface { + RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) + io.Closer +} + +// RoundTripper implements the http.RoundTripper interface +type RoundTripper struct { + mutex sync.Mutex + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // QuicConfig is the quic.Config used for dialing new connections. + // If nil, reasonable default values will be used. + QuicConfig *quic.Config + + // Enable support for HTTP/3 datagrams. + // If set to true, QuicConfig.EnableDatagram will be set. + // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html. + EnableDatagrams bool + + // Additional HTTP/3 settings. + // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + AdditionalSettings map[uint64]uint64 + + // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // If parsing the frame type fails, the error is passed to the callback. + // In that case, the frame type will not be set. + // Callers can either ignore the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + + // When set, this callback is called for unknown unidirectional stream of unknown stream type. + // If parsing the stream type fails, the error is passed to the callback. + // In that case, the stream type will not be set. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + + // Dial specifies an optional dial function for creating QUIC + // connections for requests. + // If Dial is nil, quic.DialAddrEarlyContext will be used. + Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 + + clients map[string]roundTripCloser +} + +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type RoundTripOpt struct { + // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. + // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. + OnlyCachedConn bool + // DontCloseRequestStream controls whether the request stream is closed after sending the request. + DontCloseRequestStream bool +} + +var ( + _ http.RoundTripper = &RoundTripper{} + _ io.Closer = &RoundTripper{} +) + +// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set +var ErrNoCachedConn = errors.New("http3: no cached connection was available") + +// RoundTripOpt is like RoundTrip, but takes options. +func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + if req.URL == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.URL") + } + if req.URL.Host == "" { + closeRequestBody(req) + return nil, errors.New("http3: no Host in request URL") + } + if req.Header == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.Header") + } + + if req.URL.Scheme == "https" { + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("http3: invalid http header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) + } + } + } + } else { + closeRequestBody(req) + return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) + } + + if req.Method != "" && !validMethod(req.Method) { + closeRequestBody(req) + return nil, fmt.Errorf("http3: invalid method %q", req.Method) + } + + hostname := authorityAddr("https", hostnameFromRequest(req)) + cl, err := r.getClient(hostname, opt.OnlyCachedConn) + if err != nil { + return nil, err + } + return cl.RoundTripOpt(req, opt) +} + +// RoundTrip does a round trip. +func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return r.RoundTripOpt(req, RoundTripOpt{}) +} + +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.clients == nil { + r.clients = make(map[string]roundTripCloser) + } + + client, ok := r.clients[hostname] + if !ok { + if onlyCached { + return nil, ErrNoCachedConn + } + var err error + client, err = newClient( + hostname, + r.TLSClientConfig, + &roundTripperOpts{ + EnableDatagram: r.EnableDatagrams, + DisableCompression: r.DisableCompression, + MaxHeaderBytes: r.MaxResponseHeaderBytes, + StreamHijacker: r.StreamHijacker, + UniStreamHijacker: r.UniStreamHijacker, + }, + r.QuicConfig, + r.Dial, + ) + if err != nil { + return nil, err + } + r.clients[hostname] = client + } + return client, nil +} + +// Close closes the QUIC connections that this RoundTripper has used +func (r *RoundTripper) Close() error { + r.mutex.Lock() + defer r.mutex.Unlock() + for _, client := range r.clients { + if err := client.Close(); err != nil { + return err + } + } + r.clients = nil + return nil +} + +func closeRequestBody(req *http.Request) { + if req.Body != nil { + req.Body.Close() + } +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +// copied from net/http/http.go +func isNotToken(r rune) bool { + return !httpguts.IsTokenRune(r) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/http3/server.go b/vendor/github.com/lucas-clemente/quic-go/http3/server.go new file mode 100644 index 00000000..0432a615 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/http3/server.go @@ -0,0 +1,742 @@ +package http3 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "runtime" + "strings" + "sync" + "time" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/marten-seemann/qpack" +) + +// allows mocking of quic.Listen and quic.ListenAddr +var ( + quicListen = quic.ListenEarly + quicListenAddr = quic.ListenAddrEarly +) + +const ( + nextProtoH3Draft29 = "h3-29" + nextProtoH3 = "h3" +) + +// StreamType is the stream type of a unidirectional stream. +type StreamType uint64 + +const ( + streamTypeControlStream = 0 + streamTypePushStream = 1 + streamTypeQPACKEncoderStream = 2 + streamTypeQPACKDecoderStream = 3 +) + +func versionToALPN(v protocol.VersionNumber) string { + if v == protocol.Version1 || v == protocol.Version2 { + return nextProtoH3 + } + if v == protocol.VersionTLS || v == protocol.VersionDraft29 { + return nextProtoH3Draft29 + } + return "" +} + +// ConfigureTLSConfig creates a new tls.Config which can be used +// to create a quic.Listener meant for serving http3. The created +// tls.Config adds the functionality of detecting the used QUIC version +// in order to set the correct ALPN value for the http3 connection. +func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { + // The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set. + // That way, we can get the QUIC version and set the correct ALPN value. + return &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + // determine the ALPN from the QUIC version used + proto := nextProtoH3 + if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { + proto = versionToALPN(qconn.GetQUICVersion()) + } + config := tlsConf + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + var err error + conf, err := getConfigForClient(ch) + if err != nil { + return nil, err + } + if conf != nil { + config = conf + } + } + if config == nil { + return nil, nil + } + config = config.Clone() + config.NextProtos = []string{proto} + return config, nil + }, + } +} + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name } + +// ServerContextKey is a context key. It can be used in HTTP +// handlers with Context.Value to access the server that +// started the handler. The associated value will be of +// type *http3.Server. +var ServerContextKey = &contextKey{"http3-server"} + +type requestError struct { + err error + streamErr errorCode + connErr errorCode +} + +func newStreamError(code errorCode, err error) requestError { + return requestError{err: err, streamErr: code} +} + +func newConnError(code errorCode, err error) requestError { + return requestError{err: err, connErr: code} +} + +// listenerInfo contains info about specific listener added with addListener +type listenerInfo struct { + port int // 0 means that no info about port is available +} + +// Server is a HTTP/3 server. +type Server struct { + // Addr optionally specifies the UDP address for the server to listen on, + // in the form "host:port". + // + // When used by ListenAndServe and ListenAndServeTLS methods, if empty, + // ":https" (port 443) is used. See net.Dial for details of the address + // format. + // + // Otherwise, if Port is not set and underlying QUIC listeners do not + // have valid port numbers, the port part is used in Alt-Svc headers set + // with SetQuicHeaders. + Addr string + + // Port is used in Alt-Svc response headers set with SetQuicHeaders. If + // needed Port can be manually set when the Server is created. + // + // This is useful when a Layer 4 firewall is redirecting UDP traffic and + // clients must use a port different from the port the Server is + // listening on. + Port int + + // TLSConfig provides a TLS configuration for use by server. It must be + // set for ListenAndServe and Serve methods. + TLSConfig *tls.Config + + // QuicConfig provides the parameters for QUIC connection created with + // Serve. If nil, it uses reasonable default values. + // + // Configured versions are also used in Alt-Svc response header set with + // SetQuicHeaders. + QuicConfig *quic.Config + + // Handler is the HTTP request handler to use. If not set, defaults to + // http.NotFound. + Handler http.Handler + + // EnableDatagrams enables support for HTTP/3 datagrams. + // If set to true, QuicConfig.EnableDatagram will be set. + // See https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-07. + EnableDatagrams bool + + // MaxHeaderBytes controls the maximum number of bytes the server will + // read parsing the request HEADERS frame. It does not limit the size of + // the request body. If zero or negative, http.DefaultMaxHeaderBytes is + // used. + MaxHeaderBytes int + + // AdditionalSettings specifies additional HTTP/3 settings. + // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + AdditionalSettings map[uint64]uint64 + + // StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // If parsing the frame type fails, the error is passed to the callback. + // In that case, the frame type will not be set. + // Callers can either ignore the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + + // UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type. + // If parsing the stream type fails, the error is passed to the callback. + // In that case, the stream type will not be set. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + + mutex sync.RWMutex + listeners map[*quic.EarlyListener]listenerInfo + + closed bool + + altSvcHeader string + + logger utils.Logger +} + +// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. +// +// If s.Addr is blank, ":https" is used. +func (s *Server) ListenAndServe() error { + return s.serveConn(s.TLSConfig, nil) +} + +// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. +// +// If s.Addr is blank, ":https" is used. +func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { + var err error + certs := make([]tls.Certificate, 1) + certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + // We currently only use the cert-related stuff from tls.Config, + // so we don't need to make a full copy. + config := &tls.Config{ + Certificates: certs, + } + return s.serveConn(config, nil) +} + +// Serve an existing UDP connection. +// It is possible to reuse the same connection for outgoing connections. +// Closing the server does not close the packet conn. +func (s *Server) Serve(conn net.PacketConn) error { + return s.serveConn(s.TLSConfig, conn) +} + +// ServeListener serves an existing QUIC listener. +// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config +// and use it to construct a http3-friendly QUIC listener. +// Closing the server does close the listener. +func (s *Server) ServeListener(ln quic.EarlyListener) error { + if err := s.addListener(&ln); err != nil { + return err + } + err := s.serveListener(ln) + s.removeListener(&ln) + return err +} + +var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") + +func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { + if tlsConf == nil { + return errServerWithoutTLSConfig + } + + s.mutex.Lock() + closed := s.closed + s.mutex.Unlock() + if closed { + return http.ErrServerClosed + } + + baseConf := ConfigureTLSConfig(tlsConf) + quicConf := s.QuicConfig + if quicConf == nil { + quicConf = &quic.Config{} + } else { + quicConf = s.QuicConfig.Clone() + } + if s.EnableDatagrams { + quicConf.EnableDatagrams = true + } + + var ln quic.EarlyListener + var err error + if conn == nil { + addr := s.Addr + if addr == "" { + addr = ":https" + } + ln, err = quicListenAddr(addr, baseConf, quicConf) + } else { + ln, err = quicListen(conn, baseConf, quicConf) + } + if err != nil { + return err + } + if err := s.addListener(&ln); err != nil { + return err + } + err = s.serveListener(ln) + s.removeListener(&ln) + return err +} + +func (s *Server) serveListener(ln quic.EarlyListener) error { + for { + conn, err := ln.Accept(context.Background()) + if err != nil { + return err + } + go s.handleConn(conn) + } +} + +func extractPort(addr string) (int, error) { + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return 0, err + } + + portInt, err := net.LookupPort("tcp", portStr) + if err != nil { + return 0, err + } + return portInt, nil +} + +func (s *Server) generateAltSvcHeader() { + if len(s.listeners) == 0 { + // Don't announce any ports since no one is listening for connections + s.altSvcHeader = "" + return + } + + // This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed. + supportedVersions := protocol.SupportedVersions + if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 { + supportedVersions = s.QuicConfig.Versions + } + + // keep track of which have been seen so we don't yield duplicate values + seen := make(map[string]struct{}, len(supportedVersions)) + var versionStrings []string + for _, version := range supportedVersions { + if v := versionToALPN(version); len(v) > 0 { + if _, ok := seen[v]; !ok { + versionStrings = append(versionStrings, v) + seen[v] = struct{}{} + } + } + } + + var altSvc []string + addPort := func(port int) { + for _, v := range versionStrings { + altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port)) + } + } + + if s.Port != 0 { + // if Port is specified, we must use it instead of the + // listener addresses since there's a reason it's specified. + addPort(s.Port) + } else { + // if we have some listeners assigned, try to find ports + // which we can announce, otherwise nothing should be announced + validPortsFound := false + for _, info := range s.listeners { + if info.port != 0 { + addPort(info.port) + validPortsFound = true + } + } + if !validPortsFound { + if port, err := extractPort(s.Addr); err == nil { + addPort(port) + } + } + } + + s.altSvcHeader = strings.Join(altSvc, ",") +} + +// We store a pointer to interface in the map set. This is safe because we only +// call trackListener via Serve and can track+defer untrack the same pointer to +// local variable there. We never need to compare a Listener from another caller. +func (s *Server) addListener(l *quic.EarlyListener) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closed { + return http.ErrServerClosed + } + if s.logger == nil { + s.logger = utils.DefaultLogger.WithPrefix("server") + } + if s.listeners == nil { + s.listeners = make(map[*quic.EarlyListener]listenerInfo) + } + + if port, err := extractPort((*l).Addr().String()); err == nil { + s.listeners[l] = listenerInfo{port} + } else { + s.logger.Errorf( + "Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) + s.listeners[l] = listenerInfo{} + } + s.generateAltSvcHeader() + return nil +} + +func (s *Server) removeListener(l *quic.EarlyListener) { + s.mutex.Lock() + delete(s.listeners, l) + s.generateAltSvcHeader() + s.mutex.Unlock() +} + +func (s *Server) handleConn(conn quic.EarlyConnection) { + decoder := qpack.NewDecoder(nil) + + // send a SETTINGS frame + str, err := conn.OpenUniStream() + if err != nil { + s.logger.Debugf("Opening the control stream failed.") + return + } + buf := &bytes.Buffer{} + quicvarint.Write(buf, streamTypeControlStream) // stream type + (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Write(buf) + str.Write(buf.Bytes()) + + go s.handleUnidirectionalStreams(conn) + + // Process all requests immediately. + // It's the client's responsibility to decide which requests are eligible for 0-RTT. + for { + str, err := conn.AcceptStream(context.Background()) + if err != nil { + s.logger.Debugf("Accepting stream failed: %s", err) + return + } + go func() { + rerr := s.handleRequest(conn, str, decoder, func() { + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") + }) + if rerr.err == errHijacked { + return + } + if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { + s.logger.Debugf("Handling request failed: %s", err) + if rerr.streamErr != 0 { + str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) + } + if rerr.connErr != 0 { + var reason string + if rerr.err != nil { + reason = rerr.err.Error() + } + conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + } + return + } + str.Close() + }() + } +} + +func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) { + for { + str, err := conn.AcceptUniStream(context.Background()) + if err != nil { + s.logger.Debugf("accepting unidirectional stream failed: %s", err) + return + } + + go func(str quic.ReceiveStream) { + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) { + return + } + s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) + return + } + // We're only interested in the control stream here. + switch streamType { + case streamTypeControlStream: + case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: + // Our QPACK implementation doesn't use the dynamic table yet. + // TODO: check that only one stream of each type is opened. + return + case streamTypePushStream: // only the server can push + conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") + return + default: + if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) { + return + } + str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + return + } + f, err := parseNextFrame(str, nil) + if err != nil { + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + return + } + sf, ok := f.(*settingsFrame) + if !ok { + conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + return + } + if !sf.Datagram { + return + } + // If datagram support was enabled on our side as well as on the client side, + // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. + // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). + if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams { + conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + } + }(str) + } +} + +func (s *Server) maxHeaderBytes() uint64 { + if s.MaxHeaderBytes <= 0 { + return http.DefaultMaxHeaderBytes + } + return uint64(s.MaxHeaderBytes) +} + +func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { + var ufh unknownFrameHandlerFunc + if s.StreamHijacker != nil { + ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) } + } + frame, err := parseNextFrame(str, ufh) + if err != nil { + if err == errHijacked { + return requestError{err: errHijacked} + } + return newStreamError(errorRequestIncomplete, err) + } + hf, ok := frame.(*headersFrame) + if !ok { + return newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) + } + if hf.Length > s.maxHeaderBytes() { + return newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes())) + } + headerBlock := make([]byte, hf.Length) + if _, err := io.ReadFull(str, headerBlock); err != nil { + return newStreamError(errorRequestIncomplete, err) + } + hfs, err := decoder.DecodeFull(headerBlock) + if err != nil { + // TODO: use the right error code + return newConnError(errorGeneralProtocolError, err) + } + req, err := requestFromHeaders(hfs) + if err != nil { + // TODO: use the right error code + return newStreamError(errorGeneralProtocolError, err) + } + + req.RemoteAddr = conn.RemoteAddr().String() + body := newRequestBody(newStream(str, onFrameError)) + req.Body = body + + if s.logger.Debug() { + s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID()) + } else { + s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI) + } + + ctx := str.Context() + ctx = context.WithValue(ctx, ServerContextKey, s) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) + req = req.WithContext(ctx) + r := newResponseWriter(str, conn, s.logger) + defer r.Flush() + handler := s.Handler + if handler == nil { + handler = http.DefaultServeMux + } + + var panicked bool + func() { + defer func() { + if p := recover(); p != nil { + // Copied from net/http/server.go + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + s.logger.Errorf("http: panic serving: %v\n%s", p, buf) + panicked = true + } + }() + handler.ServeHTTP(r, req) + }() + + if body.wasStreamHijacked() { + return requestError{err: errHijacked} + } + + if panicked { + r.WriteHeader(500) + } else { + r.WriteHeader(200) + } + // If the EOF was read by the handler, CancelRead() is a no-op. + str.CancelRead(quic.StreamErrorCode(errorNoError)) + return requestError{} +} + +// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. +// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. +func (s *Server) Close() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.closed = true + + var err error + for ln := range s.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } + } + return err +} + +// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. +// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. +func (s *Server) CloseGracefully(timeout time.Duration) error { + // TODO: implement + return nil +} + +// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found +// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port +// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr. +var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr") + +// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3. +// The values set by default advertise all of the ports the server is listening on, but can be +// changed to a specific port by setting Server.Port before launching the serverr. +// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used +// to extract the port, if specified. +// For example, a server launched using ListenAndServe on an address with port 443 would set: +// Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +func (s *Server) SetQuicHeaders(hdr http.Header) error { + s.mutex.RLock() + defer s.mutex.RUnlock() + + if s.altSvcHeader == "" { + return ErrNoAltSvcPort + } + // use the map directly to avoid constant canonicalization + // since the key is already canonicalized + hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader) + return nil +} + +// ListenAndServeQUIC listens on the UDP network address addr and calls the +// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is +// used when handler is nil. +func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { + server := &Server{ + Addr: addr, + Handler: handler, + } + return server.ListenAndServeTLS(certFile, keyFile) +} + +// ListenAndServe listens on the given network address for both, TLS and QUIC +// connections in parallel. It returns if one of the two returns an error. +// http.DefaultServeMux is used when handler is nil. +// The correct Alt-Svc headers for QUIC are set. +func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error { + // Load certs + var err error + certs := make([]tls.Certificate, 1) + certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + // We currently only use the cert-related stuff from tls.Config, + // so we don't need to make a full copy. + config := &tls.Config{ + Certificates: certs, + } + + if addr == "" { + addr = ":https" + } + + // Open the listeners + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + defer udpConn.Close() + + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return err + } + tcpConn, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return err + } + defer tcpConn.Close() + + tlsConn := tls.NewListener(tcpConn, config) + defer tlsConn.Close() + + // Start the servers + httpServer := &http.Server{} + quicServer := &Server{ + TLSConfig: config, + } + + if handler == nil { + handler = http.DefaultServeMux + } + httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + quicServer.SetQuicHeaders(w.Header()) + handler.ServeHTTP(w, r) + }) + + hErr := make(chan error) + qErr := make(chan error) + go func() { + hErr <- httpServer.Serve(tlsConn) + }() + go func() { + qErr <- quicServer.Serve(udpConn) + }() + + select { + case err := <-hErr: + quicServer.Close() + return err + case err := <-qErr: + // Cannot close the HTTP server or wait for requests to complete properly :/ + return err + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/interface.go b/vendor/github.com/lucas-clemente/quic-go/interface.go new file mode 100644 index 00000000..6130b549 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/interface.go @@ -0,0 +1,328 @@ +package quic + +import ( + "context" + "errors" + "io" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/logging" +) + +// The StreamID is the ID of a QUIC stream. +type StreamID = protocol.StreamID + +// A VersionNumber is a QUIC version number. +type VersionNumber = protocol.VersionNumber + +const ( + // VersionDraft29 is IETF QUIC draft-29 + VersionDraft29 = protocol.VersionDraft29 + // Version1 is RFC 9000 + Version1 = protocol.Version1 + Version2 = protocol.Version2 +) + +// A Token can be used to verify the ownership of the client address. +type Token struct { + // IsRetryToken encodes how the client received the token. There are two ways: + // * In a Retry packet sent when trying to establish a new connection. + // * In a NEW_TOKEN frame on a previous connection. + IsRetryToken bool + RemoteAddr string + SentTime time.Time +} + +// A ClientToken is a token received by the client. +// It can be used to skip address validation on future connection attempts. +type ClientToken struct { + data []byte +} + +type TokenStore interface { + // Pop searches for a ClientToken associated with the given key. + // Since tokens are not supposed to be reused, it must remove the token from the cache. + // It returns nil when no token is found. + Pop(key string) (token *ClientToken) + + // Put adds a token to the cache with the given key. It might get called + // multiple times in a connection. + Put(key string, token *ClientToken) +} + +// Err0RTTRejected is the returned from: +// * Open{Uni}Stream{Sync} +// * Accept{Uni}Stream +// * Stream.Read and Stream.Write +// when the server rejects a 0-RTT connection attempt. +var Err0RTTRejected = errors.New("0-RTT rejected") + +// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection. +// It is set on the Connection.Context() context, +// as well as on the context passed to logging.Tracer.NewConnectionTracer. +var ConnectionTracingKey = connTracingCtxKey{} + +type connTracingCtxKey struct{} + +// Stream is the interface implemented by QUIC streams +// In addition to the errors listed on the Connection, +// calls to stream functions can return a StreamError if the stream is canceled. +type Stream interface { + ReceiveStream + SendStream + // SetDeadline sets the read and write deadlines associated + // with the connection. It is equivalent to calling both + // SetReadDeadline and SetWriteDeadline. + SetDeadline(t time.Time) error +} + +// A ReceiveStream is a unidirectional Receive Stream. +type ReceiveStream interface { + // StreamID returns the stream ID. + StreamID() StreamID + // Read reads data from the stream. + // Read can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetDeadline and SetReadDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + io.Reader + // CancelRead aborts receiving on this stream. + // It will ask the peer to stop transmitting stream data. + // Read will unblock immediately, and future Read calls will fail. + // When called multiple times or after reading the io.EOF it is a no-op. + CancelRead(StreamErrorCode) + // SetReadDeadline sets the deadline for future Read calls and + // any currently-blocked Read call. + // A zero value for t means Read will not time out. + + SetReadDeadline(t time.Time) error +} + +// A SendStream is a unidirectional Send Stream. +type SendStream interface { + // StreamID returns the stream ID. + StreamID() StreamID + // Write writes data to the stream. + // Write can be made to time out and return a net.Error with Timeout() == true + // after a fixed time limit; see SetDeadline and SetWriteDeadline. + // If the stream was canceled by the peer, the error implements the StreamError + // interface, and Canceled() == true. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + io.Writer + // Close closes the write-direction of the stream. + // Future calls to Write are not permitted after calling Close. + // It must not be called concurrently with Write. + // It must not be called after calling CancelWrite. + io.Closer + // CancelWrite aborts sending on this stream. + // Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably. + // Write will unblock immediately, and future calls to Write will fail. + // When called multiple times or after closing the stream it is a no-op. + CancelWrite(StreamErrorCode) + // The Context is canceled as soon as the write-side of the stream is closed. + // This happens when Close() or CancelWrite() is called, or when the peer + // cancels the read-side of their stream. + Context() context.Context + // SetWriteDeadline sets the deadline for future Write calls + // and any currently-blocked Write call. + // Even if write times out, it may return n > 0, indicating that + // some data was successfully written. + // A zero value for t means Write will not time out. + SetWriteDeadline(t time.Time) error +} + +// A Connection is a QUIC connection between two peers. +// Calls to the connection (and to streams) can return the following types of errors: +// * ApplicationError: for errors triggered by the application running on top of QUIC +// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer) +// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error) +// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error) +// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error) +// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers +type Connection interface { + // AcceptStream returns the next stream opened by the peer, blocking until one is available. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + AcceptStream(context.Context) (Stream, error) + // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. + // If the connection was closed due to a timeout, the error satisfies + // the net.Error interface, and Timeout() will be true. + AcceptUniStream(context.Context) (ReceiveStream, error) + // OpenStream opens a new bidirectional QUIC stream. + // There is no signaling to the peer about new streams: + // The peer can only accept the stream after data has been sent on the stream. + // If the error is non-nil, it satisfies the net.Error interface. + // When reaching the peer's stream limit, err.Temporary() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenStream() (Stream, error) + // OpenStreamSync opens a new bidirectional QUIC stream. + // It blocks until a new stream can be opened. + // If the error is non-nil, it satisfies the net.Error interface. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenStreamSync(context.Context) (Stream, error) + // OpenUniStream opens a new outgoing unidirectional QUIC stream. + // If the error is non-nil, it satisfies the net.Error interface. + // When reaching the peer's stream limit, Temporary() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenUniStream() (SendStream, error) + // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream. + // It blocks until a new stream can be opened. + // If the error is non-nil, it satisfies the net.Error interface. + // If the connection was closed due to a timeout, Timeout() will be true. + OpenUniStreamSync(context.Context) (SendStream, error) + // LocalAddr returns the local address. + LocalAddr() net.Addr + // RemoteAddr returns the address of the peer. + RemoteAddr() net.Addr + // CloseWithError closes the connection with an error. + // The error string will be sent to the peer. + CloseWithError(ApplicationErrorCode, string) error + // The context is cancelled when the connection is closed. + Context() context.Context + // ConnectionState returns basic details about the QUIC connection. + // It blocks until the handshake completes. + // Warning: This API should not be considered stable and might change soon. + ConnectionState() ConnectionState + + // SendMessage sends a message as a datagram, as specified in RFC 9221. + SendMessage([]byte) error + // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. + ReceiveMessage() ([]byte, error) +} + +// An EarlyConnection is a connection that is handshaking. +// Data sent during the handshake is encrypted using the forward secure keys. +// When using client certificates, the client's identity is only verified +// after completion of the handshake. +type EarlyConnection interface { + Connection + + // HandshakeComplete blocks until the handshake completes (or fails). + // Data sent before completion of the handshake is encrypted with 1-RTT keys. + // Note that the client's identity hasn't been verified yet. + HandshakeComplete() context.Context + + NextConnection() Connection +} + +// Config contains all configuration data needed for a QUIC server or client. +type Config struct { + // The QUIC versions that can be negotiated. + // If not set, it uses all versions available. + Versions []VersionNumber + // The length of the connection ID in bytes. + // It can be 0, or any value between 4 and 18. + // If not set, the interpretation depends on where the Config is used: + // If used for dialing an address, a 0 byte connection ID will be used. + // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. + // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. + ConnectionIDLength int + // HandshakeIdleTimeout is the idle timeout before completion of the handshake. + // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. + // If this value is zero, the timeout is set to 5 seconds. + HandshakeIdleTimeout time.Duration + // MaxIdleTimeout is the maximum duration that may pass without any incoming network activity. + // The actual value for the idle timeout is the minimum of this value and the peer's. + // This value only applies after the handshake has completed. + // If the timeout is exceeded, the connection is closed. + // If this value is zero, the timeout is set to 30 seconds. + MaxIdleTimeout time.Duration + // AcceptToken determines if a Token is accepted. + // It is called with token = nil if the client didn't send a token. + // If not set, a default verification function is used: + // * it verifies that the address matches, and + // * if the token is a retry token, that it was issued within the last 5 seconds + // * else, that it was issued within the last 24 hours. + // This option is only valid for the server. + AcceptToken func(clientAddr net.Addr, token *Token) bool + // The TokenStore stores tokens received from the server. + // Tokens are used to skip address validation on future connection attempts. + // The key used to store tokens is the ServerName from the tls.Config, if set + // otherwise the token is associated with the server's IP address. + TokenStore TokenStore + // InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data. + // If the application is consuming data quickly enough, the flow control auto-tuning algorithm + // will increase the window up to MaxStreamReceiveWindow. + // If this value is zero, it will default to 512 KB. + InitialStreamReceiveWindow uint64 + // MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data. + // If this value is zero, it will default to 6 MB. + MaxStreamReceiveWindow uint64 + // InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data. + // If the application is consuming data quickly enough, the flow control auto-tuning algorithm + // will increase the window up to MaxConnectionReceiveWindow. + // If this value is zero, it will default to 512 KB. + InitialConnectionReceiveWindow uint64 + // MaxConnectionReceiveWindow is the connection-level flow control window for receiving data. + // If this value is zero, it will default to 15 MB. + MaxConnectionReceiveWindow uint64 + // AllowConnectionWindowIncrease is called every time the connection flow controller attempts + // to increase the connection flow control window. + // If set, the caller can prevent an increase of the window. Typically, it would do so to + // limit the memory usage. + // To avoid deadlocks, it is not valid to call other functions on the connection or on streams + // in this callback. + AllowConnectionWindowIncrease func(sess Connection, delta uint64) bool + // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. + // Values above 2^60 are invalid. + // If not set, it will default to 100. + // If set to a negative value, it doesn't allow any bidirectional streams. + MaxIncomingStreams int64 + // MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open. + // Values above 2^60 are invalid. + // If not set, it will default to 100. + // If set to a negative value, it doesn't allow any unidirectional streams. + MaxIncomingUniStreams int64 + // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. + StatelessResetKey []byte + // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. + // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most + // every half of MaxIdleTimeout, whichever is smaller). + KeepAlivePeriod time.Duration + // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). + // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. + // Note that if Path MTU discovery is causing issues on your system, please open a new issue + DisablePathMTUDiscovery bool + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. + // This can be useful if version information is exchanged out-of-band. + // It has no effect for a client. + DisableVersionNegotiationPackets bool + // See https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/. + // Datagrams will only be available when both peers enable datagram support. + EnableDatagrams bool + Tracer logging.Tracer +} + +// ConnectionState records basic details about a QUIC connection +type ConnectionState struct { + TLS handshake.ConnectionState + SupportsDatagrams bool +} + +// A Listener for incoming QUIC connections +type Listener interface { + // Close the server. All active connections will be closed. + Close() error + // Addr returns the local network addr that the server is listening on. + Addr() net.Addr + // Accept returns new connections. It should be called in a loop. + Accept(context.Context) (Connection, error) +} + +// An EarlyListener listens for incoming QUIC connections, +// and returns them before the handshake completes. +type EarlyListener interface { + // Close the server. All active connections will be closed. + Close() error + // Addr returns the local network addr that the server is listening on. + Addr() net.Addr + // Accept returns new early connections. It should be called in a loop. + Accept(context.Context) (EarlyConnection, error) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ack_eliciting.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ack_eliciting.go new file mode 100644 index 00000000..b8cd558a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ack_eliciting.go @@ -0,0 +1,20 @@ +package ackhandler + +import "github.com/lucas-clemente/quic-go/internal/wire" + +// IsFrameAckEliciting returns true if the frame is ack-eliciting. +func IsFrameAckEliciting(f wire.Frame) bool { + _, isAck := f.(*wire.AckFrame) + _, isConnectionClose := f.(*wire.ConnectionCloseFrame) + return !isAck && !isConnectionClose +} + +// HasAckElicitingFrames returns true if at least one frame is ack-eliciting. +func HasAckElicitingFrames(fs []Frame) bool { + for _, f := range fs { + if IsFrameAckEliciting(f.Frame) { + return true + } + } + return false +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go new file mode 100644 index 00000000..26291321 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go @@ -0,0 +1,21 @@ +package ackhandler + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" +) + +// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler +func NewAckHandler( + initialPacketNumber protocol.PacketNumber, + initialMaxDatagramSize protocol.ByteCount, + rttStats *utils.RTTStats, + pers protocol.Perspective, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, +) (SentPacketHandler, ReceivedPacketHandler) { + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger) + return sph, newReceivedPacketHandler(sph, rttStats, logger, version) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go new file mode 100644 index 00000000..aed6038d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go @@ -0,0 +1,9 @@ +package ackhandler + +import "github.com/lucas-clemente/quic-go/internal/wire" + +type Frame struct { + wire.Frame // nil if the frame has already been acknowledged in another packet + OnLost func(wire.Frame) + OnAcked func(wire.Frame) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go new file mode 100644 index 00000000..32235f81 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go @@ -0,0 +1,3 @@ +package ackhandler + +//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go new file mode 100644 index 00000000..5777d97a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go @@ -0,0 +1,68 @@ +package ackhandler + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// A Packet is a packet +type Packet struct { + PacketNumber protocol.PacketNumber + Frames []Frame + LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK + Length protocol.ByteCount + EncryptionLevel protocol.EncryptionLevel + SendTime time.Time + + IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. + + includedInBytesInFlight bool + declaredLost bool + skippedPacket bool +} + +// SentPacketHandler handles ACKs received for outgoing packets +type SentPacketHandler interface { + // SentPacket may modify the packet + SentPacket(packet *Packet) + ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) + ReceivedBytes(protocol.ByteCount) + DropPackets(protocol.EncryptionLevel) + ResetForRetry() error + SetHandshakeConfirmed() + + // The SendMode determines if and what kind of packets can be sent. + SendMode() SendMode + // TimeUntilSend is the time when the next packet should be sent. + // It is used for pacing packets. + TimeUntilSend() time.Time + // HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment. + HasPacingBudget() bool + SetMaxDatagramSize(count protocol.ByteCount) + + // only to be called once the handshake is complete + QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ + + PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) + PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber + + GetLossDetectionTimeout() time.Time + OnLossDetectionTimeout() error +} + +type sentPacketTracker interface { + GetLowestPacketNotConfirmedAcked() protocol.PacketNumber + ReceivedPacket(protocol.EncryptionLevel) +} + +// ReceivedPacketHandler handles ACKs needed to send for incoming packets +type ReceivedPacketHandler interface { + IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool + ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error + DropPackets(protocol.EncryptionLevel) + + GetAlarmTimeout() time.Time + GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go new file mode 100644 index 00000000..e957d253 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go @@ -0,0 +1,3 @@ +package ackhandler + +//go:generate sh -c "../../mockgen_private.sh ackhandler mock_sent_packet_tracker_test.go github.com/lucas-clemente/quic-go/internal/ackhandler sentPacketTracker" diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go new file mode 100644 index 00000000..bb74f4ef --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package ackhandler + +// Linked list implementation from the Go standard library. + +// PacketElement is an element of a linked list. +type PacketElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *PacketElement + + // The list to which this element belongs. + list *PacketList + + // The value stored with this element. + Value Packet +} + +// Next returns the next list element or nil. +func (e *PacketElement) Next() *PacketElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *PacketElement) Prev() *PacketElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// PacketList is a linked list of Packets. +type PacketList struct { + root PacketElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *PacketList) Init() *PacketList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewPacketList returns an initialized list. +func NewPacketList() *PacketList { return new(PacketList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *PacketList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *PacketList) Front() *PacketElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *PacketList) Back() *PacketElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *PacketList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *PacketList) insert(e, at *PacketElement) *PacketElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { + return l.insert(&PacketElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *PacketList) remove(e *PacketElement) *PacketElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *PacketList) Remove(e *PacketElement) Packet { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *PacketList) PushFront(v Packet) *PacketElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *PacketList) PushBack(v Packet) *PacketElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketList) MoveToFront(e *PacketElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketList) MoveToBack(e *PacketElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketList) MoveBefore(e, mark *PacketElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketList) MoveAfter(e, mark *PacketElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketList) PushBackList(other *PacketList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketList) PushFrontList(other *PacketList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_number_generator.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_number_generator.go new file mode 100644 index 00000000..7d58650c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_number_generator.go @@ -0,0 +1,76 @@ +package ackhandler + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type packetNumberGenerator interface { + Peek() protocol.PacketNumber + Pop() protocol.PacketNumber +} + +type sequentialPacketNumberGenerator struct { + next protocol.PacketNumber +} + +var _ packetNumberGenerator = &sequentialPacketNumberGenerator{} + +func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator { + return &sequentialPacketNumberGenerator{next: initial} +} + +func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { + return p.next +} + +func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { + next := p.next + p.next++ + return next +} + +// The skippingPacketNumberGenerator generates the packet number for the next packet +// it randomly skips a packet number every averagePeriod packets (on average). +// It is guaranteed to never skip two consecutive packet numbers. +type skippingPacketNumberGenerator struct { + period protocol.PacketNumber + maxPeriod protocol.PacketNumber + + next protocol.PacketNumber + nextToSkip protocol.PacketNumber + + rng utils.Rand +} + +var _ packetNumberGenerator = &skippingPacketNumberGenerator{} + +func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator { + g := &skippingPacketNumberGenerator{ + next: initial, + period: initialPeriod, + maxPeriod: maxPeriod, + } + g.generateNewSkip() + return g +} + +func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { + return p.next +} + +func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { + next := p.next + p.next++ // generate a new packet number for the next packet + if p.next == p.nextToSkip { + p.next++ + p.generateNewSkip() + } + return next +} + +func (p *skippingPacketNumberGenerator) generateNewSkip() { + // make sure that there are never two consecutive packet numbers that are skipped + p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) + p.period = utils.MinPacketNumber(2*p.period, p.maxPeriod) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go new file mode 100644 index 00000000..89fb30d3 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go @@ -0,0 +1,136 @@ +package ackhandler + +import ( + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type receivedPacketHandler struct { + sentPackets sentPacketTracker + + initialPackets *receivedPacketTracker + handshakePackets *receivedPacketTracker + appDataPackets *receivedPacketTracker + + lowest1RTTPacket protocol.PacketNumber +} + +var _ ReceivedPacketHandler = &receivedPacketHandler{} + +func newReceivedPacketHandler( + sentPackets sentPacketTracker, + rttStats *utils.RTTStats, + logger utils.Logger, + version protocol.VersionNumber, +) ReceivedPacketHandler { + return &receivedPacketHandler{ + sentPackets: sentPackets, + initialPackets: newReceivedPacketTracker(rttStats, logger, version), + handshakePackets: newReceivedPacketTracker(rttStats, logger, version), + appDataPackets: newReceivedPacketTracker(rttStats, logger, version), + lowest1RTTPacket: protocol.InvalidPacketNumber, + } +} + +func (h *receivedPacketHandler) ReceivedPacket( + pn protocol.PacketNumber, + ecn protocol.ECN, + encLevel protocol.EncryptionLevel, + rcvTime time.Time, + shouldInstigateAck bool, +) error { + h.sentPackets.ReceivedPacket(encLevel) + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + case protocol.EncryptionHandshake: + h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + case protocol.Encryption0RTT: + if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { + return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) + } + h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + case protocol.Encryption1RTT: + if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { + h.lowest1RTTPacket = pn + } + h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) + h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + default: + panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) + } + return nil +} + +func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + //nolint:exhaustive // 1-RTT packet number space is never dropped. + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets = nil + case protocol.EncryptionHandshake: + h.handshakePackets = nil + case protocol.Encryption0RTT: + // Nothing to do here. + // If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted. + default: + panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) + } +} + +func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { + var initialAlarm, handshakeAlarm time.Time + if h.initialPackets != nil { + initialAlarm = h.initialPackets.GetAlarmTimeout() + } + if h.handshakePackets != nil { + handshakeAlarm = h.handshakePackets.GetAlarmTimeout() + } + oneRTTAlarm := h.appDataPackets.GetAlarmTimeout() + return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) +} + +func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { + var ack *wire.AckFrame + //nolint:exhaustive // 0-RTT packets can't contain ACK frames. + switch encLevel { + case protocol.EncryptionInitial: + if h.initialPackets != nil { + ack = h.initialPackets.GetAckFrame(onlyIfQueued) + } + case protocol.EncryptionHandshake: + if h.handshakePackets != nil { + ack = h.handshakePackets.GetAckFrame(onlyIfQueued) + } + case protocol.Encryption1RTT: + // 0-RTT packets can't contain ACK frames + return h.appDataPackets.GetAckFrame(onlyIfQueued) + default: + return nil + } + // For Initial and Handshake ACKs, the delay time is ignored by the receiver. + // Set it to 0 in order to save bytes. + if ack != nil { + ack.DelayTime = 0 + } + return ack +} + +func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { + switch encLevel { + case protocol.EncryptionInitial: + if h.initialPackets != nil { + return h.initialPackets.IsPotentiallyDuplicate(pn) + } + case protocol.EncryptionHandshake: + if h.handshakePackets != nil { + return h.handshakePackets.IsPotentiallyDuplicate(pn) + } + case protocol.Encryption0RTT, protocol.Encryption1RTT: + return h.appDataPackets.IsPotentiallyDuplicate(pn) + } + panic("unexpected encryption level") +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go new file mode 100644 index 00000000..5a0391ef --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go @@ -0,0 +1,142 @@ +package ackhandler + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// The receivedPacketHistory stores if a packet number has already been received. +// It generates ACK ranges which can be used to assemble an ACK frame. +// It does not store packet contents. +type receivedPacketHistory struct { + ranges *utils.PacketIntervalList + + deletedBelow protocol.PacketNumber +} + +func newReceivedPacketHistory() *receivedPacketHistory { + return &receivedPacketHistory{ + ranges: utils.NewPacketIntervalList(), + } +} + +// ReceivedPacket registers a packet with PacketNumber p and updates the ranges +func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { + // ignore delayed packets, if we already deleted the range + if p < h.deletedBelow { + return false + } + isNew := h.addToRanges(p) + h.maybeDeleteOldRanges() + return isNew +} + +func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { + if h.ranges.Len() == 0 { + h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) + return true + } + + for el := h.ranges.Back(); el != nil; el = el.Prev() { + // p already included in an existing range. Nothing to do here + if p >= el.Value.Start && p <= el.Value.End { + return false + } + + if el.Value.End == p-1 { // extend a range at the end + el.Value.End = p + return true + } + if el.Value.Start == p+1 { // extend a range at the beginning + el.Value.Start = p + + prev := el.Prev() + if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges + prev.Value.End = el.Value.End + h.ranges.Remove(el) + } + return true + } + + // create a new range at the end + if p > el.Value.End { + h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el) + return true + } + } + + // create a new range at the beginning + h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front()) + return true +} + +// Delete old ranges, if we're tracking more than 500 of them. +// This is a DoS defense against a peer that sends us too many gaps. +func (h *receivedPacketHistory) maybeDeleteOldRanges() { + for h.ranges.Len() > protocol.MaxNumAckRanges { + h.ranges.Remove(h.ranges.Front()) + } +} + +// DeleteBelow deletes all entries below (but not including) p +func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) { + if p < h.deletedBelow { + return + } + h.deletedBelow = p + + nextEl := h.ranges.Front() + for el := h.ranges.Front(); nextEl != nil; el = nextEl { + nextEl = el.Next() + + if el.Value.End < p { // delete a whole range + h.ranges.Remove(el) + } else if p > el.Value.Start && p <= el.Value.End { + el.Value.Start = p + return + } else { // no ranges affected. Nothing to do + return + } + } +} + +// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame +func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { + if h.ranges.Len() == 0 { + return nil + } + + ackRanges := make([]wire.AckRange, h.ranges.Len()) + i := 0 + for el := h.ranges.Back(); el != nil; el = el.Prev() { + ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End} + i++ + } + return ackRanges +} + +func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange { + ackRange := wire.AckRange{} + if h.ranges.Len() > 0 { + r := h.ranges.Back().Value + ackRange.Smallest = r.Start + ackRange.Largest = r.End + } + return ackRange +} + +func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool { + if p < h.deletedBelow { + return true + } + for el := h.ranges.Back(); el != nil; el = el.Prev() { + if p > el.Value.End { + return false + } + if p <= el.Value.End && p >= el.Value.Start { + return true + } + } + return false +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go new file mode 100644 index 00000000..56e79269 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go @@ -0,0 +1,196 @@ +package ackhandler + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// number of ack-eliciting packets received before sending an ack. +const packetsBeforeAck = 2 + +type receivedPacketTracker struct { + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber + largestObservedReceivedTime time.Time + ect0, ect1, ecnce uint64 + + packetHistory *receivedPacketHistory + + maxAckDelay time.Duration + rttStats *utils.RTTStats + + hasNewAck bool // true as soon as we received an ack-eliciting new packet + ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets + + ackElicitingPacketsReceivedSinceLastAck int + ackAlarm time.Time + lastAck *wire.AckFrame + + logger utils.Logger + + version protocol.VersionNumber +} + +func newReceivedPacketTracker( + rttStats *utils.RTTStats, + logger utils.Logger, + version protocol.VersionNumber, +) *receivedPacketTracker { + return &receivedPacketTracker{ + packetHistory: newReceivedPacketHistory(), + maxAckDelay: protocol.MaxAckDelay, + rttStats: rttStats, + logger: logger, + version: version, + } +} + +func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) { + if packetNumber < h.ignoreBelow { + return + } + + isMissing := h.isMissing(packetNumber) + if packetNumber >= h.largestObserved { + h.largestObserved = packetNumber + h.largestObservedReceivedTime = rcvTime + } + + if isNew := h.packetHistory.ReceivedPacket(packetNumber); isNew && shouldInstigateAck { + h.hasNewAck = true + } + if shouldInstigateAck { + h.maybeQueueAck(packetNumber, rcvTime, isMissing) + } + switch ecn { + case protocol.ECNNon: + case protocol.ECT0: + h.ect0++ + case protocol.ECT1: + h.ect1++ + case protocol.ECNCE: + h.ecnce++ + } +} + +// IgnoreBelow sets a lower limit for acknowledging packets. +// Packets with packet numbers smaller than p will not be acked. +func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) { + if p <= h.ignoreBelow { + return + } + h.ignoreBelow = p + h.packetHistory.DeleteBelow(p) + if h.logger.Debug() { + h.logger.Debugf("\tIgnoring all packets below %d.", p) + } +} + +// isMissing says if a packet was reported missing in the last ACK. +func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool { + if h.lastAck == nil || p < h.ignoreBelow { + return false + } + return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) +} + +func (h *receivedPacketTracker) hasNewMissingPackets() bool { + if h.lastAck == nil { + return false + } + highestRange := h.packetHistory.GetHighestAckRange() + return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 +} + +// maybeQueueAck queues an ACK, if necessary. +func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { + // always acknowledge the first packet + if h.lastAck == nil { + if !h.ackQueued { + h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") + } + h.ackQueued = true + return + } + + if h.ackQueued { + return + } + + h.ackElicitingPacketsReceivedSinceLastAck++ + + // Send an ACK if this packet was reported missing in an ACK sent before. + // Ack decimation with reordering relies on the timer to send an ACK, but if + // missing packets we reported in the previous ack, send an ACK immediately. + if wasMissing { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) + } + h.ackQueued = true + } + + // send an ACK every 2 ack-eliciting packets + if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck { + if h.logger.Debug() { + h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck) + } + h.ackQueued = true + } else if h.ackAlarm.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) + } + h.ackAlarm = rcvTime.Add(h.maxAckDelay) + } + + // Queue an ACK if there are new missing packets to report. + if h.hasNewMissingPackets() { + h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") + h.ackQueued = true + } + + if h.ackQueued { + // cancel the ack alarm + h.ackAlarm = time.Time{} + } +} + +func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { + if !h.hasNewAck { + return nil + } + now := time.Now() + if onlyIfQueued { + if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { + return nil + } + if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + h.logger.Debugf("Sending ACK because the ACK timer expired.") + } + } + + ack := &wire.AckFrame{ + AckRanges: h.packetHistory.GetAckRanges(), + // Make sure that the DelayTime is always positive. + // This is not guaranteed on systems that don't have a monotonic clock. + DelayTime: utils.MaxDuration(0, now.Sub(h.largestObservedReceivedTime)), + ECT0: h.ect0, + ECT1: h.ect1, + ECNCE: h.ecnce, + } + + h.lastAck = ack + h.ackAlarm = time.Time{} + h.ackQueued = false + h.hasNewAck = false + h.ackElicitingPacketsReceivedSinceLastAck = 0 + return ack +} + +func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } + +func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { + return h.packetHistory.IsPotentiallyDuplicate(pn) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go new file mode 100644 index 00000000..3d5fe560 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go @@ -0,0 +1,40 @@ +package ackhandler + +import "fmt" + +// The SendMode says what kind of packets can be sent. +type SendMode uint8 + +const ( + // SendNone means that no packets should be sent + SendNone SendMode = iota + // SendAck means an ACK-only packet should be sent + SendAck + // SendPTOInitial means that an Initial probe packet should be sent + SendPTOInitial + // SendPTOHandshake means that a Handshake probe packet should be sent + SendPTOHandshake + // SendPTOAppData means that an Application data probe packet should be sent + SendPTOAppData + // SendAny means that any packet should be sent + SendAny +) + +func (s SendMode) String() string { + switch s { + case SendNone: + return "none" + case SendAck: + return "ack" + case SendPTOInitial: + return "pto (Initial)" + case SendPTOHandshake: + return "pto (Handshake)" + case SendPTOAppData: + return "pto (Application Data)" + case SendAny: + return "any" + default: + return fmt.Sprintf("invalid send mode: %d", s) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go new file mode 100644 index 00000000..7df91f23 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go @@ -0,0 +1,838 @@ +package ackhandler + +import ( + "errors" + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/congestion" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" +) + +const ( + // Maximum reordering in time space before time based loss detection considers a packet lost. + // Specified as an RTT multiplier. + timeThreshold = 9.0 / 8 + // Maximum reordering in packets before packet threshold loss detection considers a packet lost. + packetThreshold = 3 + // Before validating the client's address, the server won't send more than 3x bytes than it received. + amplificationFactor = 3 + // We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet. + minRTTAfterRetry = 5 * time.Millisecond +) + +type packetNumberSpace struct { + history *sentPacketHistory + pns packetNumberGenerator + + lossTime time.Time + lastAckElicitingPacketTime time.Time + + largestAcked protocol.PacketNumber + largestSent protocol.PacketNumber +} + +func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { + var pns packetNumberGenerator + if skipPNs { + pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) + } else { + pns = newSequentialPacketNumberGenerator(initialPN) + } + return &packetNumberSpace{ + history: newSentPacketHistory(rttStats), + pns: pns, + largestSent: protocol.InvalidPacketNumber, + largestAcked: protocol.InvalidPacketNumber, + } +} + +type sentPacketHandler struct { + initialPackets *packetNumberSpace + handshakePackets *packetNumberSpace + appDataPackets *packetNumberSpace + + // Do we know that the peer completed address validation yet? + // Always true for the server. + peerCompletedAddressValidation bool + bytesReceived protocol.ByteCount + bytesSent protocol.ByteCount + // Have we validated the peer's address yet? + // Always true for the client. + peerAddressValidated bool + + handshakeConfirmed bool + + // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived + // example: we send an ACK for packets 90-100 with packet number 20 + // once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101 + // Only applies to the application-data packet number space. + lowestNotConfirmedAcked protocol.PacketNumber + + ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets + + bytesInFlight protocol.ByteCount + + congestion congestion.SendAlgorithmWithDebugInfos + rttStats *utils.RTTStats + + // The number of times a PTO has been sent without receiving an ack. + ptoCount uint32 + ptoMode SendMode + // The number of PTO probe packets that should be sent. + // Only applies to the application-data packet number space. + numProbesToSend int + + // The alarm timeout + alarm time.Time + + perspective protocol.Perspective + + tracer logging.ConnectionTracer + logger utils.Logger +} + +var ( + _ SentPacketHandler = &sentPacketHandler{} + _ sentPacketTracker = &sentPacketHandler{} +) + +func newSentPacketHandler( + initialPN protocol.PacketNumber, + initialMaxDatagramSize protocol.ByteCount, + rttStats *utils.RTTStats, + pers protocol.Perspective, + tracer logging.ConnectionTracer, + logger utils.Logger, +) *sentPacketHandler { + congestion := congestion.NewCubicSender( + congestion.DefaultClock{}, + rttStats, + initialMaxDatagramSize, + true, // use Reno + tracer, + ) + + return &sentPacketHandler{ + peerCompletedAddressValidation: pers == protocol.PerspectiveServer, + peerAddressValidated: pers == protocol.PerspectiveClient, + initialPackets: newPacketNumberSpace(initialPN, false, rttStats), + handshakePackets: newPacketNumberSpace(0, false, rttStats), + appDataPackets: newPacketNumberSpace(0, true, rttStats), + rttStats: rttStats, + congestion: congestion, + perspective: pers, + tracer: tracer, + logger: logger, + } +} + +func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial { + // This function is called when the crypto setup seals a Handshake packet. + // If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space + // before SentPacket() was called for that Initial packet. + return + } + h.dropPackets(encLevel) +} + +func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) { + if p.includedInBytesInFlight { + if p.Length > h.bytesInFlight { + panic("negative bytes_in_flight") + } + h.bytesInFlight -= p.Length + p.includedInBytesInFlight = false + } +} + +func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { + // The server won't await address validation after the handshake is confirmed. + // This applies even if we didn't receive an ACK for a Handshake packet. + if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake { + h.peerCompletedAddressValidation = true + } + // remove outstanding packets from bytes_in_flight + if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { + pnSpace := h.getPacketNumberSpace(encLevel) + pnSpace.history.Iterate(func(p *Packet) (bool, error) { + h.removeFromBytesInFlight(p) + return true, nil + }) + } + // drop the packet history + //nolint:exhaustive // Not every packet number space can be dropped. + switch encLevel { + case protocol.EncryptionInitial: + h.initialPackets = nil + case protocol.EncryptionHandshake: + h.handshakePackets = nil + case protocol.Encryption0RTT: + // This function is only called when 0-RTT is rejected, + // and not when the client drops 0-RTT keys when the handshake completes. + // When 0-RTT is rejected, all application data sent so far becomes invalid. + // Delete the packets from the history and remove them from bytes_in_flight. + h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { + if p.EncryptionLevel != protocol.Encryption0RTT { + return false, nil + } + h.removeFromBytesInFlight(p) + h.appDataPackets.history.Remove(p.PacketNumber) + return true, nil + }) + default: + panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) + } + if h.tracer != nil && h.ptoCount != 0 { + h.tracer.UpdatedPTOCount(0) + } + h.ptoCount = 0 + h.numProbesToSend = 0 + h.ptoMode = SendNone + h.setLossDetectionTimer() +} + +func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { + wasAmplificationLimit := h.isAmplificationLimited() + h.bytesReceived += n + if wasAmplificationLimit && !h.isAmplificationLimited() { + h.setLossDetectionTimer() + } +} + +func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated { + h.peerAddressValidated = true + h.setLossDetectionTimer() + } +} + +func (h *sentPacketHandler) packetsInFlight() int { + packetsInFlight := h.appDataPackets.history.Len() + if h.handshakePackets != nil { + packetsInFlight += h.handshakePackets.history.Len() + } + if h.initialPackets != nil { + packetsInFlight += h.initialPackets.history.Len() + } + return packetsInFlight +} + +func (h *sentPacketHandler) SentPacket(packet *Packet) { + h.bytesSent += packet.Length + // For the client, drop the Initial packet number space when the first Handshake packet is sent. + if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { + h.dropPackets(protocol.EncryptionInitial) + } + isAckEliciting := h.sentPacketImpl(packet) + h.getPacketNumberSpace(packet.EncryptionLevel).history.SentPacket(packet, isAckEliciting) + if h.tracer != nil && isAckEliciting { + h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + } + if isAckEliciting || !h.peerCompletedAddressValidation { + h.setLossDetectionTimer() + } +} + +func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { + switch encLevel { + case protocol.EncryptionInitial: + return h.initialPackets + case protocol.EncryptionHandshake: + return h.handshakePackets + case protocol.Encryption0RTT, protocol.Encryption1RTT: + return h.appDataPackets + default: + panic("invalid packet number space") + } +} + +func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ { + pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel) + + if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { + for p := utils.MaxPacketNumber(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ { + h.logger.Debugf("Skipping packet number %d", p) + } + } + + pnSpace.largestSent = packet.PacketNumber + isAckEliciting := len(packet.Frames) > 0 + + if isAckEliciting { + pnSpace.lastAckElicitingPacketTime = packet.SendTime + packet.includedInBytesInFlight = true + h.bytesInFlight += packet.Length + if h.numProbesToSend > 0 { + h.numProbesToSend-- + } + } + h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting) + + return isAckEliciting +} + +func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) { + pnSpace := h.getPacketNumberSpace(encLevel) + + largestAcked := ack.LargestAcked() + if largestAcked > pnSpace.largestSent { + return false, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "received ACK for an unsent packet", + } + } + + pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked) + + // Servers complete address validation when a protected packet is received. + if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && + (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { + h.peerCompletedAddressValidation = true + h.logger.Debugf("Peer doesn't await address validation any longer.") + // Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets. + h.setLossDetectionTimer() + } + + priorInFlight := h.bytesInFlight + ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel) + if err != nil || len(ackedPackets) == 0 { + return false, err + } + // update the RTT, if the largest acked is newly acknowledged + if len(ackedPackets) > 0 { + if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() { + // don't use the ack delay for Initial and Handshake packets + var ackDelay time.Duration + if encLevel == protocol.Encryption1RTT { + ackDelay = utils.MinDuration(ack.DelayTime, h.rttStats.MaxAckDelay()) + } + h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) + if h.logger.Debug() { + h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) + } + h.congestion.MaybeExitSlowStart() + } + } + if err := h.detectLostPackets(rcvTime, encLevel); err != nil { + return false, err + } + var acked1RTTPacket bool + for _, p := range ackedPackets { + if p.includedInBytesInFlight && !p.declaredLost { + h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) + } + if p.EncryptionLevel == protocol.Encryption1RTT { + acked1RTTPacket = true + } + h.removeFromBytesInFlight(p) + } + + // Reset the pto_count unless the client is unsure if the server has validated the client's address. + if h.peerCompletedAddressValidation { + if h.tracer != nil && h.ptoCount != 0 { + h.tracer.UpdatedPTOCount(0) + } + h.ptoCount = 0 + } + h.numProbesToSend = 0 + + if h.tracer != nil { + h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + } + + pnSpace.history.DeleteOldPackets(rcvTime) + h.setLossDetectionTimer() + return acked1RTTPacket, nil +} + +func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber { + return h.lowestNotConfirmedAcked +} + +// Packets are returned in ascending packet number order. +func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { + pnSpace := h.getPacketNumberSpace(encLevel) + h.ackedPackets = h.ackedPackets[:0] + ackRangeIndex := 0 + lowestAcked := ack.LowestAcked() + largestAcked := ack.LargestAcked() + err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { + // Ignore packets below the lowest acked + if p.PacketNumber < lowestAcked { + return true, nil + } + // Break after largest acked is reached + if p.PacketNumber > largestAcked { + return false, nil + } + + if ack.HasMissingRanges() { + ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] + + for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 { + ackRangeIndex++ + ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex] + } + + if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range + return true, nil + } + if p.PacketNumber > ackRange.Largest { + return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest) + } + } + if p.skippedPacket { + return false, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel), + } + } + h.ackedPackets = append(h.ackedPackets, p) + return true, nil + }) + if h.logger.Debug() && len(h.ackedPackets) > 0 { + pns := make([]protocol.PacketNumber, len(h.ackedPackets)) + for i, p := range h.ackedPackets { + pns[i] = p.PacketNumber + } + h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns) + } + + for _, p := range h.ackedPackets { + if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { + h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) + } + + for _, f := range p.Frames { + if f.OnAcked != nil { + f.OnAcked(f.Frame) + } + } + if err := pnSpace.history.Remove(p.PacketNumber); err != nil { + return nil, err + } + if h.tracer != nil { + h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) + } + } + + return h.ackedPackets, err +} + +func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) { + var encLevel protocol.EncryptionLevel + var lossTime time.Time + + if h.initialPackets != nil { + lossTime = h.initialPackets.lossTime + encLevel = protocol.EncryptionInitial + } + if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) { + lossTime = h.handshakePackets.lossTime + encLevel = protocol.EncryptionHandshake + } + if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) { + lossTime = h.appDataPackets.lossTime + encLevel = protocol.Encryption1RTT + } + return lossTime, encLevel +} + +// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime +func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { + // We only send application data probe packets once the handshake is confirmed, + // because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets. + if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() { + if h.peerCompletedAddressValidation { + return + } + t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount) + if h.initialPackets != nil { + return t, protocol.EncryptionInitial, true + } + return t, protocol.EncryptionHandshake, true + } + + if h.initialPackets != nil { + encLevel = protocol.EncryptionInitial + if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { + pto = t.Add(h.rttStats.PTO(false) << h.ptoCount) + } + } + if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { + t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(false) << h.ptoCount) + if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { + pto = t + encLevel = protocol.EncryptionHandshake + } + } + if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { + t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) + if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { + pto = t + encLevel = protocol.Encryption1RTT + } + } + return pto, encLevel, true +} + +func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool { + if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() { + return true + } + if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() { + return true + } + return false +} + +func (h *sentPacketHandler) hasOutstandingPackets() bool { + return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets() +} + +func (h *sentPacketHandler) setLossDetectionTimer() { + oldAlarm := h.alarm // only needed in case tracing is enabled + lossTime, encLevel := h.getLossTimeAndSpace() + if !lossTime.IsZero() { + // Early retransmit timer or time loss detection. + h.alarm = lossTime + if h.tracer != nil && h.alarm != oldAlarm { + h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) + } + return + } + + // Cancel the alarm if amplification limited. + if h.isAmplificationLimited() { + h.alarm = time.Time{} + if !oldAlarm.IsZero() { + h.logger.Debugf("Canceling loss detection timer. Amplification limited.") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } + } + return + } + + // Cancel the alarm if no packets are outstanding + if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation { + h.alarm = time.Time{} + if !oldAlarm.IsZero() { + h.logger.Debugf("Canceling loss detection timer. No packets in flight.") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } + } + return + } + + // PTO alarm + ptoTime, encLevel, ok := h.getPTOTimeAndSpace() + if !ok { + if !oldAlarm.IsZero() { + h.alarm = time.Time{} + h.logger.Debugf("Canceling loss detection timer. No PTO needed..") + if h.tracer != nil { + h.tracer.LossTimerCanceled() + } + } + return + } + h.alarm = ptoTime + if h.tracer != nil && h.alarm != oldAlarm { + h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) + } +} + +func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error { + pnSpace := h.getPacketNumberSpace(encLevel) + pnSpace.lossTime = time.Time{} + + maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) + lossDelay := time.Duration(timeThreshold * maxRTT) + + // Minimum time of granularity before packets are deemed lost. + lossDelay = utils.MaxDuration(lossDelay, protocol.TimerGranularity) + + // Packets sent before this time are deemed lost. + lostSendTime := now.Add(-lossDelay) + + priorInFlight := h.bytesInFlight + return pnSpace.history.Iterate(func(p *Packet) (bool, error) { + if p.PacketNumber > pnSpace.largestAcked { + return false, nil + } + if p.declaredLost || p.skippedPacket { + return true, nil + } + + var packetLost bool + if p.SendTime.Before(lostSendTime) { + packetLost = true + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) + } + if h.tracer != nil { + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) + } + } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { + packetLost = true + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) + } + if h.tracer != nil { + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) + } + } else if pnSpace.lossTime.IsZero() { + // Note: This conditional is only entered once per call + lossTime := p.SendTime.Add(lossDelay) + if h.logger.Debug() { + h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime) + } + pnSpace.lossTime = lossTime + } + if packetLost { + p.declaredLost = true + // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted + h.removeFromBytesInFlight(p) + h.queueFramesForRetransmission(p) + if !p.IsPathMTUProbePacket { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } + } + return true, nil + }) +} + +func (h *sentPacketHandler) OnLossDetectionTimeout() error { + defer h.setLossDetectionTimer() + earliestLossTime, encLevel := h.getLossTimeAndSpace() + if !earliestLossTime.IsZero() { + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) + } + if h.tracer != nil { + h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) + } + // Early retransmit or time loss detection + return h.detectLostPackets(time.Now(), encLevel) + } + + // PTO + // When all outstanding are acknowledged, the alarm is canceled in + // setLossDetectionTimer. This doesn't reset the timer in the session though. + // When OnAlarm is called, we therefore need to make sure that there are + // actually packets outstanding. + if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation { + h.ptoCount++ + h.numProbesToSend++ + if h.initialPackets != nil { + h.ptoMode = SendPTOInitial + } else if h.handshakePackets != nil { + h.ptoMode = SendPTOHandshake + } else { + return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped") + } + return nil + } + + _, encLevel, ok := h.getPTOTimeAndSpace() + if !ok { + return nil + } + if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation { + return nil + } + h.ptoCount++ + if h.logger.Debug() { + h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) + } + if h.tracer != nil { + h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) + h.tracer.UpdatedPTOCount(h.ptoCount) + } + h.numProbesToSend += 2 + //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. + switch encLevel { + case protocol.EncryptionInitial: + h.ptoMode = SendPTOInitial + case protocol.EncryptionHandshake: + h.ptoMode = SendPTOHandshake + case protocol.Encryption1RTT: + // skip a packet number in order to elicit an immediate ACK + _ = h.PopPacketNumber(protocol.Encryption1RTT) + h.ptoMode = SendPTOAppData + default: + return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) + } + return nil +} + +func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { + return h.alarm +} + +func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { + pnSpace := h.getPacketNumberSpace(encLevel) + + var lowestUnacked protocol.PacketNumber + if p := pnSpace.history.FirstOutstanding(); p != nil { + lowestUnacked = p.PacketNumber + } else { + lowestUnacked = pnSpace.largestAcked + 1 + } + + pn := pnSpace.pns.Peek() + return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked) +} + +func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { + return h.getPacketNumberSpace(encLevel).pns.Pop() +} + +func (h *sentPacketHandler) SendMode() SendMode { + numTrackedPackets := h.appDataPackets.history.Len() + if h.initialPackets != nil { + numTrackedPackets += h.initialPackets.history.Len() + } + if h.handshakePackets != nil { + numTrackedPackets += h.handshakePackets.history.Len() + } + + if h.isAmplificationLimited() { + h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent) + return SendNone + } + // Don't send any packets if we're keeping track of the maximum number of packets. + // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, + // we will stop sending out new data when reaching MaxOutstandingSentPackets, + // but still allow sending of retransmissions and ACKs. + if numTrackedPackets >= protocol.MaxTrackedSentPackets { + if h.logger.Debug() { + h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets) + } + return SendNone + } + if h.numProbesToSend > 0 { + return h.ptoMode + } + // Only send ACKs if we're congestion limited. + if !h.congestion.CanSend(h.bytesInFlight) { + if h.logger.Debug() { + h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow()) + } + return SendAck + } + if numTrackedPackets >= protocol.MaxOutstandingSentPackets { + if h.logger.Debug() { + h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets) + } + return SendAck + } + return SendAny +} + +func (h *sentPacketHandler) TimeUntilSend() time.Time { + return h.congestion.TimeUntilSend(h.bytesInFlight) +} + +func (h *sentPacketHandler) HasPacingBudget() bool { + return h.congestion.HasPacingBudget() +} + +func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { + h.congestion.SetMaxDatagramSize(s) +} + +func (h *sentPacketHandler) isAmplificationLimited() bool { + if h.peerAddressValidated { + return false + } + return h.bytesSent >= amplificationFactor*h.bytesReceived +} + +func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { + pnSpace := h.getPacketNumberSpace(encLevel) + p := pnSpace.history.FirstOutstanding() + if p == nil { + return false + } + h.queueFramesForRetransmission(p) + // TODO: don't declare the packet lost here. + // Keep track of acknowledged frames instead. + h.removeFromBytesInFlight(p) + p.declaredLost = true + return true +} + +func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { + if len(p.Frames) == 0 { + panic("no frames") + } + for _, f := range p.Frames { + f.OnLost(f.Frame) + } + p.Frames = nil +} + +func (h *sentPacketHandler) ResetForRetry() error { + h.bytesInFlight = 0 + var firstPacketSendTime time.Time + h.initialPackets.history.Iterate(func(p *Packet) (bool, error) { + if firstPacketSendTime.IsZero() { + firstPacketSendTime = p.SendTime + } + if p.declaredLost || p.skippedPacket { + return true, nil + } + h.queueFramesForRetransmission(p) + return true, nil + }) + // All application data packets sent at this point are 0-RTT packets. + // In the case of a Retry, we can assume that the server dropped all of them. + h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { + if !p.declaredLost && !p.skippedPacket { + h.queueFramesForRetransmission(p) + } + return true, nil + }) + + // Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial. + // Otherwise, we don't know which Initial the Retry was sent in response to. + if h.ptoCount == 0 { + // Don't set the RTT to a value lower than 5ms here. + now := time.Now() + h.rttStats.UpdateRTT(utils.MaxDuration(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) + if h.logger.Debug() { + h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) + } + if h.tracer != nil { + h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + } + } + h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) + h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) + oldAlarm := h.alarm + h.alarm = time.Time{} + if h.tracer != nil { + h.tracer.UpdatedPTOCount(0) + if !oldAlarm.IsZero() { + h.tracer.LossTimerCanceled() + } + } + h.ptoCount = 0 + return nil +} + +func (h *sentPacketHandler) SetHandshakeConfirmed() { + h.handshakeConfirmed = true + // We don't send PTOs for application data packets before the handshake completes. + // Make sure the timer is armed now, if necessary. + h.setLossDetectionTimer() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go new file mode 100644 index 00000000..36489367 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go @@ -0,0 +1,108 @@ +package ackhandler + +import ( + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type sentPacketHistory struct { + rttStats *utils.RTTStats + packetList *PacketList + packetMap map[protocol.PacketNumber]*PacketElement + highestSent protocol.PacketNumber +} + +func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { + return &sentPacketHistory{ + rttStats: rttStats, + packetList: NewPacketList(), + packetMap: make(map[protocol.PacketNumber]*PacketElement), + highestSent: protocol.InvalidPacketNumber, + } +} + +func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { + if p.PacketNumber <= h.highestSent { + panic("non-sequential packet number use") + } + // Skipped packet numbers. + for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ { + el := h.packetList.PushBack(Packet{ + PacketNumber: pn, + EncryptionLevel: p.EncryptionLevel, + SendTime: p.SendTime, + skippedPacket: true, + }) + h.packetMap[pn] = el + } + h.highestSent = p.PacketNumber + + if isAckEliciting { + el := h.packetList.PushBack(*p) + h.packetMap[p.PacketNumber] = el + } +} + +// Iterate iterates through all packets. +func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { + cont := true + var next *PacketElement + for el := h.packetList.Front(); cont && el != nil; el = next { + var err error + next = el.Next() + cont, err = cb(&el.Value) + if err != nil { + return err + } + } + return nil +} + +// FirstOutStanding returns the first outstanding packet. +func (h *sentPacketHistory) FirstOutstanding() *Packet { + for el := h.packetList.Front(); el != nil; el = el.Next() { + p := &el.Value + if !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket { + return p + } + } + return nil +} + +func (h *sentPacketHistory) Len() int { + return len(h.packetMap) +} + +func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { + el, ok := h.packetMap[p] + if !ok { + return fmt.Errorf("packet %d not found in sent packet history", p) + } + h.packetList.Remove(el) + delete(h.packetMap, p) + return nil +} + +func (h *sentPacketHistory) HasOutstandingPackets() bool { + return h.FirstOutstanding() != nil +} + +func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { + maxAge := 3 * h.rttStats.PTO(false) + var nextEl *PacketElement + for el := h.packetList.Front(); el != nil; el = nextEl { + nextEl = el.Next() + p := el.Value + if p.SendTime.After(now.Add(-maxAge)) { + break + } + if !p.skippedPacket && !p.declaredLost { // should only happen in the case of drastic RTT changes + continue + } + delete(h.packetMap, p.PacketNumber) + h.packetList.Remove(el) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go new file mode 100644 index 00000000..96b1c5aa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go @@ -0,0 +1,25 @@ +package congestion + +import ( + "math" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// Bandwidth of a connection +type Bandwidth uint64 + +const infBandwidth Bandwidth = math.MaxUint64 + +const ( + // BitsPerSecond is 1 bit per second + BitsPerSecond Bandwidth = 1 + // BytesPerSecond is 1 byte per second + BytesPerSecond = 8 * BitsPerSecond +) + +// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta +func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth { + return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go new file mode 100644 index 00000000..405fae70 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go @@ -0,0 +1,18 @@ +package congestion + +import "time" + +// A Clock returns the current time +type Clock interface { + Now() time.Time +} + +// DefaultClock implements the Clock interface using the Go stdlib clock. +type DefaultClock struct{} + +var _ Clock = DefaultClock{} + +// Now gets the current time +func (DefaultClock) Now() time.Time { + return time.Now() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go new file mode 100644 index 00000000..beadd627 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go @@ -0,0 +1,214 @@ +package congestion + +import ( + "math" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// This cubic implementation is based on the one found in Chromiums's QUIC +// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. + +// Constants based on TCP defaults. +// The following constants are in 2^10 fractions of a second instead of ms to +// allow a 10 shift right to divide. + +// 1024*1024^3 (first 1024 is from 0.100^3) +// where 0.100 is 100 ms which is the scaling round trip time. +const ( + cubeScale = 40 + cubeCongestionWindowScale = 410 + cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize + // TODO: when re-enabling cubic, make sure to use the actual packet size here + maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) +) + +const defaultNumConnections = 1 + +// Default Cubic backoff factor +const beta float32 = 0.7 + +// Additional backoff factor when loss occurs in the concave part of the Cubic +// curve. This additional backoff factor is expected to give up bandwidth to +// new concurrent flows and speed up convergence. +const betaLastMax float32 = 0.85 + +// Cubic implements the cubic algorithm from TCP +type Cubic struct { + clock Clock + + // Number of connections to simulate. + numConnections int + + // Time when this cycle started, after last loss event. + epoch time.Time + + // Max congestion window used just before last loss event. + // Note: to improve fairness to other streams an additional back off is + // applied to this value if the new value is below our latest value. + lastMaxCongestionWindow protocol.ByteCount + + // Number of acked bytes since the cycle started (epoch). + ackedBytesCount protocol.ByteCount + + // TCP Reno equivalent congestion window in packets. + estimatedTCPcongestionWindow protocol.ByteCount + + // Origin point of cubic function. + originPointCongestionWindow protocol.ByteCount + + // Time to origin point of cubic function in 2^10 fractions of a second. + timeToOriginPoint uint32 + + // Last congestion window in packets computed by cubic function. + lastTargetCongestionWindow protocol.ByteCount +} + +// NewCubic returns a new Cubic instance +func NewCubic(clock Clock) *Cubic { + c := &Cubic{ + clock: clock, + numConnections: defaultNumConnections, + } + c.Reset() + return c +} + +// Reset is called after a timeout to reset the cubic state +func (c *Cubic) Reset() { + c.epoch = time.Time{} + c.lastMaxCongestionWindow = 0 + c.ackedBytesCount = 0 + c.estimatedTCPcongestionWindow = 0 + c.originPointCongestionWindow = 0 + c.timeToOriginPoint = 0 + c.lastTargetCongestionWindow = 0 +} + +func (c *Cubic) alpha() float32 { + // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that + // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. + // We derive the equivalent alpha for an N-connection emulation as: + b := c.beta() + return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) +} + +func (c *Cubic) beta() float32 { + // kNConnectionBeta is the backoff factor after loss for our N-connection + // emulation, which emulates the effective backoff of an ensemble of N + // TCP-Reno connections on a single loss event. The effective multiplier is + // computed as: + return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) +} + +func (c *Cubic) betaLastMax() float32 { + // betaLastMax is the additional backoff factor after loss for our + // N-connection emulation, which emulates the additional backoff of + // an ensemble of N TCP-Reno connections on a single loss event. The + // effective multiplier is computed as: + return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) +} + +// OnApplicationLimited is called on ack arrival when sender is unable to use +// the available congestion window. Resets Cubic state during quiescence. +func (c *Cubic) OnApplicationLimited() { + // When sender is not using the available congestion window, the window does + // not grow. But to be RTT-independent, Cubic assumes that the sender has been + // using the entire window during the time since the beginning of the current + // "epoch" (the end of the last loss recovery period). Since + // application-limited periods break this assumption, we reset the epoch when + // in such a period. This reset effectively freezes congestion window growth + // through application-limited periods and allows Cubic growth to continue + // when the entire window is being used. + c.epoch = time.Time{} +} + +// CongestionWindowAfterPacketLoss computes a new congestion window to use after +// a loss event. Returns the new congestion window in packets. The new +// congestion window is a multiplicative decrease of our current window. +func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount { + if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { + // We never reached the old max, so assume we are competing with another + // flow. Use our extra back off factor to allow the other flow to go up. + c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) + } else { + c.lastMaxCongestionWindow = currentCongestionWindow + } + c.epoch = time.Time{} // Reset time. + return protocol.ByteCount(float32(currentCongestionWindow) * c.beta()) +} + +// CongestionWindowAfterAck computes a new congestion window to use after a received ACK. +// Returns the new congestion window in packets. The new congestion window +// follows a cubic function that depends on the time passed since last +// packet loss. +func (c *Cubic) CongestionWindowAfterAck( + ackedBytes protocol.ByteCount, + currentCongestionWindow protocol.ByteCount, + delayMin time.Duration, + eventTime time.Time, +) protocol.ByteCount { + c.ackedBytesCount += ackedBytes + + if c.epoch.IsZero() { + // First ACK after a loss event. + c.epoch = eventTime // Start of epoch. + c.ackedBytesCount = ackedBytes // Reset count. + // Reset estimated_tcp_congestion_window_ to be in sync with cubic. + c.estimatedTCPcongestionWindow = currentCongestionWindow + if c.lastMaxCongestionWindow <= currentCongestionWindow { + c.timeToOriginPoint = 0 + c.originPointCongestionWindow = currentCongestionWindow + } else { + c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) + c.originPointCongestionWindow = c.lastMaxCongestionWindow + } + } + + // Change the time unit from microseconds to 2^10 fractions per second. Take + // the round trip time in account. This is done to allow us to use shift as a + // divide operator. + elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) + + // Right-shifts of negative, signed numbers have implementation-dependent + // behavior, so force the offset to be positive, as is done in the kernel. + offset := int64(c.timeToOriginPoint) - elapsedTime + if offset < 0 { + offset = -offset + } + + deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale + var targetCongestionWindow protocol.ByteCount + if elapsedTime > int64(c.timeToOriginPoint) { + targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow + } else { + targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow + } + // Limit the CWND increase to half the acked bytes. + targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + + // Increase the window by approximately Alpha * 1 MSS of bytes every + // time we ack an estimated tcp window of bytes. For small + // congestion windows (less than 25), the formula below will + // increase slightly slower than linearly per estimated tcp window + // of bytes. + c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) + c.ackedBytesCount = 0 + + // We have a new cubic congestion window. + c.lastTargetCongestionWindow = targetCongestionWindow + + // Compute target congestion_window based on cubic target and estimated TCP + // congestion_window, use highest (fastest). + if targetCongestionWindow < c.estimatedTCPcongestionWindow { + targetCongestionWindow = c.estimatedTCPcongestionWindow + } + return targetCongestionWindow +} + +// SetNumConnections sets the number of emulated connections +func (c *Cubic) SetNumConnections(n int) { + c.numConnections = n +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go new file mode 100644 index 00000000..059b8f6a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go @@ -0,0 +1,316 @@ +package congestion + +import ( + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" +) + +const ( + // maxDatagramSize is the default maximum packet size used in the Linux TCP implementation. + // Used in QUIC for congestion window computations in bytes. + initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4) + maxBurstPackets = 3 + renoBeta = 0.7 // Reno backoff factor. + minCongestionWindowPackets = 2 + initialCongestionWindow = 32 +) + +type cubicSender struct { + hybridSlowStart HybridSlowStart + rttStats *utils.RTTStats + cubic *Cubic + pacer *pacer + clock Clock + + reno bool + + // Track the largest packet that has been sent. + largestSentPacketNumber protocol.PacketNumber + + // Track the largest packet that has been acked. + largestAckedPacketNumber protocol.PacketNumber + + // Track the largest packet number outstanding when a CWND cutback occurs. + largestSentAtLastCutback protocol.PacketNumber + + // Whether the last loss event caused us to exit slowstart. + // Used for stats collection of slowstartPacketsLost + lastCutbackExitedSlowstart bool + + // Congestion window in bytes. + congestionWindow protocol.ByteCount + + // Slow start congestion window in bytes, aka ssthresh. + slowStartThreshold protocol.ByteCount + + // ACK counter for the Reno implementation. + numAckedPackets uint64 + + initialCongestionWindow protocol.ByteCount + initialMaxCongestionWindow protocol.ByteCount + + maxDatagramSize protocol.ByteCount + + lastState logging.CongestionState + tracer logging.ConnectionTracer +} + +var ( + _ SendAlgorithm = &cubicSender{} + _ SendAlgorithmWithDebugInfos = &cubicSender{} +) + +// NewCubicSender makes a new cubic sender +func NewCubicSender( + clock Clock, + rttStats *utils.RTTStats, + initialMaxDatagramSize protocol.ByteCount, + reno bool, + tracer logging.ConnectionTracer, +) *cubicSender { + return newCubicSender( + clock, + rttStats, + reno, + initialMaxDatagramSize, + initialCongestionWindow*initialMaxDatagramSize, + protocol.MaxCongestionWindowPackets*initialMaxDatagramSize, + tracer, + ) +} + +func newCubicSender( + clock Clock, + rttStats *utils.RTTStats, + reno bool, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow protocol.ByteCount, + tracer logging.ConnectionTracer, +) *cubicSender { + c := &cubicSender{ + rttStats: rttStats, + largestSentPacketNumber: protocol.InvalidPacketNumber, + largestAckedPacketNumber: protocol.InvalidPacketNumber, + largestSentAtLastCutback: protocol.InvalidPacketNumber, + initialCongestionWindow: initialCongestionWindow, + initialMaxCongestionWindow: initialMaxCongestionWindow, + congestionWindow: initialCongestionWindow, + slowStartThreshold: protocol.MaxByteCount, + cubic: NewCubic(clock), + clock: clock, + reno: reno, + tracer: tracer, + maxDatagramSize: initialMaxDatagramSize, + } + c.pacer = newPacer(c.BandwidthEstimate) + if c.tracer != nil { + c.lastState = logging.CongestionStateSlowStart + c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) + } + return c +} + +// TimeUntilSend returns when the next packet should be sent. +func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time { + return c.pacer.TimeUntilSend() +} + +func (c *cubicSender) HasPacingBudget() bool { + return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize +} + +func (c *cubicSender) maxCongestionWindow() protocol.ByteCount { + return c.maxDatagramSize * protocol.MaxCongestionWindowPackets +} + +func (c *cubicSender) minCongestionWindow() protocol.ByteCount { + return c.maxDatagramSize * minCongestionWindowPackets +} + +func (c *cubicSender) OnPacketSent( + sentTime time.Time, + _ protocol.ByteCount, + packetNumber protocol.PacketNumber, + bytes protocol.ByteCount, + isRetransmittable bool, +) { + c.pacer.SentPacket(sentTime, bytes) + if !isRetransmittable { + return + } + c.largestSentPacketNumber = packetNumber + c.hybridSlowStart.OnPacketSent(packetNumber) +} + +func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool { + return bytesInFlight < c.GetCongestionWindow() +} + +func (c *cubicSender) InRecovery() bool { + return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback +} + +func (c *cubicSender) InSlowStart() bool { + return c.GetCongestionWindow() < c.slowStartThreshold +} + +func (c *cubicSender) GetCongestionWindow() protocol.ByteCount { + return c.congestionWindow +} + +func (c *cubicSender) MaybeExitSlowStart() { + if c.InSlowStart() && + c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) { + // exit slow start + c.slowStartThreshold = c.congestionWindow + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) + } +} + +func (c *cubicSender) OnPacketAcked( + ackedPacketNumber protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { + c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) + if c.InRecovery() { + return + } + c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime) + if c.InSlowStart() { + c.hybridSlowStart.OnPacketAcked(ackedPacketNumber) + } +} + +func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { + // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets + // already sent should be treated as a single loss event, since it's expected. + if packetNumber <= c.largestSentAtLastCutback { + return + } + c.lastCutbackExitedSlowstart = c.InSlowStart() + c.maybeTraceStateChange(logging.CongestionStateRecovery) + + if c.reno { + c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta) + } else { + c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow) + } + if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd { + c.congestionWindow = minCwnd + } + c.slowStartThreshold = c.congestionWindow + c.largestSentAtLastCutback = c.largestSentPacketNumber + // reset packet count from congestion avoidance mode. We start + // counting again when we're out of recovery. + c.numAckedPackets = 0 +} + +// Called when we receive an ack. Normal TCP tracks how many packets one ack +// represents, but quic has a separate ack for each packet. +func (c *cubicSender) maybeIncreaseCwnd( + _ protocol.PacketNumber, + ackedBytes protocol.ByteCount, + priorInFlight protocol.ByteCount, + eventTime time.Time, +) { + // Do not increase the congestion window unless the sender is close to using + // the current window. + if !c.isCwndLimited(priorInFlight) { + c.cubic.OnApplicationLimited() + c.maybeTraceStateChange(logging.CongestionStateApplicationLimited) + return + } + if c.congestionWindow >= c.maxCongestionWindow() { + return + } + if c.InSlowStart() { + // TCP slow start, exponential growth, increase by one for each ACK. + c.congestionWindow += c.maxDatagramSize + c.maybeTraceStateChange(logging.CongestionStateSlowStart) + return + } + // Congestion avoidance + c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance) + if c.reno { + // Classic Reno congestion avoidance. + c.numAckedPackets++ + if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) { + c.congestionWindow += c.maxDatagramSize + c.numAckedPackets = 0 + } + } else { + c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + } +} + +func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool { + congestionWindow := c.GetCongestionWindow() + if bytesInFlight >= congestionWindow { + return true + } + availableBytes := congestionWindow - bytesInFlight + slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2 + return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize +} + +// BandwidthEstimate returns the current bandwidth estimate +func (c *cubicSender) BandwidthEstimate() Bandwidth { + srtt := c.rttStats.SmoothedRTT() + if srtt == 0 { + // If we haven't measured an rtt, the bandwidth estimate is unknown. + return infBandwidth + } + return BandwidthFromDelta(c.GetCongestionWindow(), srtt) +} + +// OnRetransmissionTimeout is called on an retransmission timeout +func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + c.largestSentAtLastCutback = protocol.InvalidPacketNumber + if !packetsRetransmitted { + return + } + c.hybridSlowStart.Restart() + c.cubic.Reset() + c.slowStartThreshold = c.congestionWindow / 2 + c.congestionWindow = c.minCongestionWindow() +} + +// OnConnectionMigration is called when the connection is migrated (?) +func (c *cubicSender) OnConnectionMigration() { + c.hybridSlowStart.Restart() + c.largestSentPacketNumber = protocol.InvalidPacketNumber + c.largestAckedPacketNumber = protocol.InvalidPacketNumber + c.largestSentAtLastCutback = protocol.InvalidPacketNumber + c.lastCutbackExitedSlowstart = false + c.cubic.Reset() + c.numAckedPackets = 0 + c.congestionWindow = c.initialCongestionWindow + c.slowStartThreshold = c.initialMaxCongestionWindow +} + +func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { + if c.tracer == nil || new == c.lastState { + return + } + c.tracer.UpdatedCongestionState(new) + c.lastState = new +} + +func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) { + if s < c.maxDatagramSize { + panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) + } + cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() + c.maxDatagramSize = s + if cwndIsMinCwnd { + c.congestionWindow = c.minCongestionWindow() + } + c.pacer.SetMaxDatagramSize(s) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go new file mode 100644 index 00000000..b5ae3d5e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go @@ -0,0 +1,113 @@ +package congestion + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// Note(pwestin): the magic clamping numbers come from the original code in +// tcp_cubic.c. +const hybridStartLowWindow = protocol.ByteCount(16) + +// Number of delay samples for detecting the increase of delay. +const hybridStartMinSamples = uint32(8) + +// Exit slow start if the min rtt has increased by more than 1/8th. +const hybridStartDelayFactorExp = 3 // 2^3 = 8 +// The original paper specifies 2 and 8ms, but those have changed over time. +const ( + hybridStartDelayMinThresholdUs = int64(4000) + hybridStartDelayMaxThresholdUs = int64(16000) +) + +// HybridSlowStart implements the TCP hybrid slow start algorithm +type HybridSlowStart struct { + endPacketNumber protocol.PacketNumber + lastSentPacketNumber protocol.PacketNumber + started bool + currentMinRTT time.Duration + rttSampleCount uint32 + hystartFound bool +} + +// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase. +func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) { + s.endPacketNumber = lastSent + s.currentMinRTT = 0 + s.rttSampleCount = 0 + s.started = true +} + +// IsEndOfRound returns true if this ack is the last packet number of our current slow start round. +func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool { + return s.endPacketNumber < ack +} + +// ShouldExitSlowStart should be called on every new ack frame, since a new +// RTT measurement can be made then. +// rtt: the RTT for this ack packet. +// minRTT: is the lowest delay (RTT) we have seen during the session. +// congestionWindow: the congestion window in packets. +func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool { + if !s.started { + // Time to start the hybrid slow start. + s.StartReceiveRound(s.lastSentPacketNumber) + } + if s.hystartFound { + return true + } + // Second detection parameter - delay increase detection. + // Compare the minimum delay (s.currentMinRTT) of the current + // burst of packets relative to the minimum delay during the session. + // Note: we only look at the first few(8) packets in each burst, since we + // only want to compare the lowest RTT of the burst relative to previous + // bursts. + s.rttSampleCount++ + if s.rttSampleCount <= hybridStartMinSamples { + if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT { + s.currentMinRTT = latestRTT + } + } + // We only need to check this once per round. + if s.rttSampleCount == hybridStartMinSamples { + // Divide minRTT by 8 to get a rtt increase threshold for exiting. + minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) + // Ensure the rtt threshold is never less than 2ms or more than 16ms. + minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) + minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond + + if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { + s.hystartFound = true + } + } + // Exit from slow start if the cwnd is greater than 16 and + // increasing delay is found. + return congestionWindow >= hybridStartLowWindow && s.hystartFound +} + +// OnPacketSent is called when a packet was sent +func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) { + s.lastSentPacketNumber = packetNumber +} + +// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end +// the round when the final packet of the burst is received and start it on +// the next incoming ack. +func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) { + if s.IsEndOfRound(ackedPacketNumber) { + s.started = false + } +} + +// Started returns true if started +func (s *HybridSlowStart) Started() bool { + return s.started +} + +// Restart the slow start phase +func (s *HybridSlowStart) Restart() { + s.started = false + s.hystartFound = false +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go new file mode 100644 index 00000000..5157383f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go @@ -0,0 +1,28 @@ +package congestion + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A SendAlgorithm performs congestion control +type SendAlgorithm interface { + TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time + HasPacingBudget() bool + OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) + CanSend(bytesInFlight protocol.ByteCount) bool + MaybeExitSlowStart() + OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) + OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) + OnRetransmissionTimeout(packetsRetransmitted bool) + SetMaxDatagramSize(protocol.ByteCount) +} + +// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos +type SendAlgorithmWithDebugInfos interface { + SendAlgorithm + InSlowStart() bool + InRecovery() bool + GetCongestionWindow() protocol.ByteCount +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/pacer.go b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/pacer.go new file mode 100644 index 00000000..7ec4d8f5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/congestion/pacer.go @@ -0,0 +1,77 @@ +package congestion + +import ( + "math" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +const maxBurstSizePackets = 10 + +// The pacer implements a token bucket pacing algorithm. +type pacer struct { + budgetAtLastSent protocol.ByteCount + maxDatagramSize protocol.ByteCount + lastSentTime time.Time + getAdjustedBandwidth func() uint64 // in bytes/s +} + +func newPacer(getBandwidth func() Bandwidth) *pacer { + p := &pacer{ + maxDatagramSize: initialMaxDatagramSize, + getAdjustedBandwidth: func() uint64 { + // Bandwidth is in bits/s. We need the value in bytes/s. + bw := uint64(getBandwidth() / BytesPerSecond) + // Use a slightly higher value than the actual measured bandwidth. + // RTT variations then won't result in under-utilization of the congestion window. + // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, + // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. + return bw * 5 / 4 + }, + } + p.budgetAtLastSent = p.maxBurstSize() + return p +} + +func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *pacer) Budget(now time.Time) protocol.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + return utils.MinByteCount(p.maxBurstSize(), budget) +} + +func (p *pacer) maxBurstSize() protocol.ByteCount { + return utils.MaxByteCount( + protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, + maxBurstSizePackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. +func (p *pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + return p.lastSentTime.Add(utils.MaxDuration( + protocol.MinPacingDelay, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, + )) +} + +func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) { + p.maxDatagramSize = s +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go new file mode 100644 index 00000000..2bf14fdc --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go @@ -0,0 +1,125 @@ +package flowcontrol + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type baseFlowController struct { + // for sending data + bytesSent protocol.ByteCount + sendWindow protocol.ByteCount + lastBlockedAt protocol.ByteCount + + // for receiving data + //nolint:structcheck // The mutex is used both by the stream and the connection flow controller + mutex sync.Mutex + bytesRead protocol.ByteCount + highestReceived protocol.ByteCount + receiveWindow protocol.ByteCount + receiveWindowSize protocol.ByteCount + maxReceiveWindowSize protocol.ByteCount + + allowWindowIncrease func(size protocol.ByteCount) bool + + epochStartTime time.Time + epochStartOffset protocol.ByteCount + rttStats *utils.RTTStats + + logger utils.Logger +} + +// IsNewlyBlocked says if it is newly blocked by flow control. +// For every offset, it only returns true once. +// If it is blocked, the offset is returned. +func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { + if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt { + return false, 0 + } + c.lastBlockedAt = c.sendWindow + return true, c.sendWindow +} + +func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { + c.bytesSent += n +} + +// UpdateSendWindow is be called after receiving a MAX_{STREAM_}DATA frame. +func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { + if offset > c.sendWindow { + c.sendWindow = offset + } +} + +func (c *baseFlowController) sendWindowSize() protocol.ByteCount { + // this only happens during connection establishment, when data is sent before we receive the peer's transport parameters + if c.bytesSent > c.sendWindow { + return 0 + } + return c.sendWindow - c.bytesSent +} + +// needs to be called with locked mutex +func (c *baseFlowController) addBytesRead(n protocol.ByteCount) { + // pretend we sent a WindowUpdate when reading the first byte + // this way auto-tuning of the window size already works for the first WindowUpdate + if c.bytesRead == 0 { + c.startNewAutoTuningEpoch(time.Now()) + } + c.bytesRead += n +} + +func (c *baseFlowController) hasWindowUpdate() bool { + bytesRemaining := c.receiveWindow - c.bytesRead + // update the window when more than the threshold was consumed + return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold)) +} + +// getWindowUpdate updates the receive window, if necessary +// it returns the new offset +func (c *baseFlowController) getWindowUpdate() protocol.ByteCount { + if !c.hasWindowUpdate() { + return 0 + } + + c.maybeAdjustWindowSize() + c.receiveWindow = c.bytesRead + c.receiveWindowSize + return c.receiveWindow +} + +// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often. +// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing. +func (c *baseFlowController) maybeAdjustWindowSize() { + bytesReadInEpoch := c.bytesRead - c.epochStartOffset + // don't do anything if less than half the window has been consumed + if bytesReadInEpoch <= c.receiveWindowSize/2 { + return + } + rtt := c.rttStats.SmoothedRTT() + if rtt == 0 { + return + } + + fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize) + now := time.Now() + if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { + // window is consumed too fast, try to increase the window size + newSize := utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize) + if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) { + c.receiveWindowSize = newSize + } + } + c.startNewAutoTuningEpoch(now) +} + +func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) { + c.epochStartTime = now + c.epochStartOffset = c.bytesRead +} + +func (c *baseFlowController) checkFlowControlViolation() bool { + return c.highestReceived > c.receiveWindow +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go new file mode 100644 index 00000000..6bf2241b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -0,0 +1,112 @@ +package flowcontrol + +import ( + "errors" + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type connectionFlowController struct { + baseFlowController + + queueWindowUpdate func() +} + +var _ ConnectionFlowController = &connectionFlowController{} + +// NewConnectionFlowController gets a new flow controller for the connection +// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0. +func NewConnectionFlowController( + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + queueWindowUpdate func(), + allowWindowIncrease func(size protocol.ByteCount) bool, + rttStats *utils.RTTStats, + logger utils.Logger, +) ConnectionFlowController { + return &connectionFlowController{ + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + allowWindowIncrease: allowWindowIncrease, + logger: logger, + }, + queueWindowUpdate: queueWindowUpdate, + } +} + +func (c *connectionFlowController) SendWindowSize() protocol.ByteCount { + return c.baseFlowController.sendWindowSize() +} + +// IncrementHighestReceived adds an increment to the highestReceived value +func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.highestReceived += increment + if c.checkFlowControlViolation() { + return &qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow), + } + } + return nil +} + +func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) { + c.mutex.Lock() + c.baseFlowController.addBytesRead(n) + shouldQueueWindowUpdate := c.hasWindowUpdate() + c.mutex.Unlock() + if shouldQueueWindowUpdate { + c.queueWindowUpdate() + } +} + +func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount { + c.mutex.Lock() + oldWindowSize := c.receiveWindowSize + offset := c.baseFlowController.getWindowUpdate() + if oldWindowSize < c.receiveWindowSize { + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10)) + } + c.mutex.Unlock() + return offset +} + +// EnsureMinimumWindowSize sets a minimum window size +// it should make sure that the connection-level window is increased when a stream-level window grows +func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) { + c.mutex.Lock() + if inc > c.receiveWindowSize { + c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) + newSize := utils.MinByteCount(inc, c.maxReceiveWindowSize) + if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { + c.receiveWindowSize = newSize + } + c.startNewAutoTuningEpoch(time.Now()) + } + c.mutex.Unlock() +} + +// Reset rests the flow controller. This happens when 0-RTT is rejected. +// All stream data is invalidated, it's if we had never opened a stream and never sent any data. +// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet. +func (c *connectionFlowController) Reset() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() { + return errors.New("flow controller reset after reading data") + } + c.bytesSent = 0 + c.lastBlockedAt = 0 + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go new file mode 100644 index 00000000..1eeaee9f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go @@ -0,0 +1,42 @@ +package flowcontrol + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +type flowController interface { + // for sending + SendWindowSize() protocol.ByteCount + UpdateSendWindow(protocol.ByteCount) + AddBytesSent(protocol.ByteCount) + // for receiving + AddBytesRead(protocol.ByteCount) + GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary + IsNewlyBlocked() (bool, protocol.ByteCount) +} + +// A StreamFlowController is a flow controller for a QUIC stream. +type StreamFlowController interface { + flowController + // for receiving + // UpdateHighestReceived should be called when a new highest offset is received + // final has to be to true if this is the final offset of the stream, + // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame + UpdateHighestReceived(offset protocol.ByteCount, final bool) error + // Abandon should be called when reading from the stream is aborted early, + // and there won't be any further calls to AddBytesRead. + Abandon() +} + +// The ConnectionFlowController is the flow controller for the connection. +type ConnectionFlowController interface { + flowController + Reset() error +} + +type connectionFlowControllerI interface { + ConnectionFlowController + // The following two methods are not supposed to be called from outside this packet, but are needed internally + // for sending + EnsureMinimumWindowSize(protocol.ByteCount) + // for receiving + IncrementHighestReceived(protocol.ByteCount) error +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go new file mode 100644 index 00000000..aa66aef1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -0,0 +1,149 @@ +package flowcontrol + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type streamFlowController struct { + baseFlowController + + streamID protocol.StreamID + + queueWindowUpdate func() + + connection connectionFlowControllerI + + receivedFinalOffset bool +} + +var _ StreamFlowController = &streamFlowController{} + +// NewStreamFlowController gets a new flow controller for a stream +func NewStreamFlowController( + streamID protocol.StreamID, + cfc ConnectionFlowController, + receiveWindow protocol.ByteCount, + maxReceiveWindow protocol.ByteCount, + initialSendWindow protocol.ByteCount, + queueWindowUpdate func(protocol.StreamID), + rttStats *utils.RTTStats, + logger utils.Logger, +) StreamFlowController { + return &streamFlowController{ + streamID: streamID, + connection: cfc.(connectionFlowControllerI), + queueWindowUpdate: func() { queueWindowUpdate(streamID) }, + baseFlowController: baseFlowController{ + rttStats: rttStats, + receiveWindow: receiveWindow, + receiveWindowSize: receiveWindow, + maxReceiveWindowSize: maxReceiveWindow, + sendWindow: initialSendWindow, + logger: logger, + }, + } +} + +// UpdateHighestReceived updates the highestReceived value, if the offset is higher. +func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error { + // If the final offset for this stream is already known, check for consistency. + if c.receivedFinalOffset { + // If we receive another final offset, check that it's the same. + if final && offset != c.highestReceived { + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset), + } + } + // Check that the offset is below the final offset. + if offset > c.highestReceived { + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived), + } + } + } + + if final { + c.receivedFinalOffset = true + } + if offset == c.highestReceived { + return nil + } + // A higher offset was received before. + // This can happen due to reordering. + if offset <= c.highestReceived { + if final { + return &qerr.TransportError{ + ErrorCode: qerr.FinalSizeError, + ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived), + } + } + return nil + } + + increment := offset - c.highestReceived + c.highestReceived = offset + if c.checkFlowControlViolation() { + return &qerr.TransportError{ + ErrorCode: qerr.FlowControlError, + ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow), + } + } + return c.connection.IncrementHighestReceived(increment) +} + +func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) { + c.mutex.Lock() + c.baseFlowController.addBytesRead(n) + shouldQueueWindowUpdate := c.shouldQueueWindowUpdate() + c.mutex.Unlock() + if shouldQueueWindowUpdate { + c.queueWindowUpdate() + } + c.connection.AddBytesRead(n) +} + +func (c *streamFlowController) Abandon() { + c.mutex.Lock() + unread := c.highestReceived - c.bytesRead + c.mutex.Unlock() + if unread > 0 { + c.connection.AddBytesRead(unread) + } +} + +func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { + c.baseFlowController.AddBytesSent(n) + c.connection.AddBytesSent(n) +} + +func (c *streamFlowController) SendWindowSize() protocol.ByteCount { + return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) +} + +func (c *streamFlowController) shouldQueueWindowUpdate() bool { + return !c.receivedFinalOffset && c.hasWindowUpdate() +} + +func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount { + // If we already received the final offset for this stream, the peer won't need any additional flow control credit. + if c.receivedFinalOffset { + return 0 + } + + // Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler + c.mutex.Lock() + oldWindowSize := c.receiveWindowSize + offset := c.baseFlowController.getWindowUpdate() + if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size + c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10)) + c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier)) + } + c.mutex.Unlock() + return offset +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go new file mode 100644 index 00000000..03b03928 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go @@ -0,0 +1,161 @@ +package handshake + +import ( + "crypto/cipher" + "encoding/binary" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen) + iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen()) + return suite.AEAD(key, iv) +} + +type longHeaderSealer struct { + aead cipher.AEAD + headerProtector headerProtector + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ LongHeaderSealer = &longHeaderSealer{} + +func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { + return &longHeaderSealer{ + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return s.aead.Seal(dst, s.nonceBuf, src, ad) +} + +func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + s.headerProtector.EncryptHeader(sample, firstByte, pnBytes) +} + +func (s *longHeaderSealer) Overhead() int { + return s.aead.Overhead() +} + +type longHeaderOpener struct { + aead cipher.AEAD + headerProtector headerProtector + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var _ LongHeaderOpener = &longHeaderOpener{} + +func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { + return &longHeaderOpener{ + aead: aead, + headerProtector: headerProtector, + nonceBuf: make([]byte, aead.NonceSize()), + } +} + +func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN) +} + +func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) + if err == nil { + o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn) + } else { + err = ErrDecryptionFailed + } + return dec, err +} + +func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { + o.headerProtector.DecryptHeader(sample, firstByte, pnBytes) +} + +type handshakeSealer struct { + LongHeaderSealer + + dropInitialKeys func() + dropped bool +} + +func newHandshakeSealer( + aead cipher.AEAD, + headerProtector headerProtector, + dropInitialKeys func(), + perspective protocol.Perspective, +) LongHeaderSealer { + sealer := newLongHeaderSealer(aead, headerProtector) + // The client drops Initial keys when sending the first Handshake packet. + if perspective == protocol.PerspectiveServer { + return sealer + } + return &handshakeSealer{ + LongHeaderSealer: sealer, + dropInitialKeys: dropInitialKeys, + } +} + +func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + data := s.LongHeaderSealer.Seal(dst, src, pn, ad) + if !s.dropped { + s.dropInitialKeys() + s.dropped = true + } + return data +} + +type handshakeOpener struct { + LongHeaderOpener + + dropInitialKeys func() + dropped bool +} + +func newHandshakeOpener( + aead cipher.AEAD, + headerProtector headerProtector, + dropInitialKeys func(), + perspective protocol.Perspective, +) LongHeaderOpener { + opener := newLongHeaderOpener(aead, headerProtector) + // The server drops Initial keys when first successfully processing a Handshake packet. + if perspective == protocol.PerspectiveClient { + return opener + } + return &handshakeOpener{ + LongHeaderOpener: opener, + dropInitialKeys: dropInitialKeys, + } +} + +func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { + dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad) + if err == nil && !o.dropped { + o.dropInitialKeys() + o.dropped = true + } + return dec, err +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go new file mode 100644 index 00000000..31d9bf0a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go @@ -0,0 +1,819 @@ +package handshake + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// TLS unexpected_message alert +const alertUnexpectedMessage uint8 = 10 + +type messageType uint8 + +// TLS handshake message types. +const ( + typeClientHello messageType = 1 + typeServerHello messageType = 2 + typeNewSessionTicket messageType = 4 + typeEncryptedExtensions messageType = 8 + typeCertificate messageType = 11 + typeCertificateRequest messageType = 13 + typeCertificateVerify messageType = 15 + typeFinished messageType = 20 +) + +func (m messageType) String() string { + switch m { + case typeClientHello: + return "ClientHello" + case typeServerHello: + return "ServerHello" + case typeNewSessionTicket: + return "NewSessionTicket" + case typeEncryptedExtensions: + return "EncryptedExtensions" + case typeCertificate: + return "Certificate" + case typeCertificateRequest: + return "CertificateRequest" + case typeCertificateVerify: + return "CertificateVerify" + case typeFinished: + return "Finished" + default: + return fmt.Sprintf("unknown message type: %d", m) + } +} + +const clientSessionStateRevision = 3 + +type conn struct { + localAddr, remoteAddr net.Addr + version protocol.VersionNumber +} + +var _ ConnWithVersion = &conn{} + +func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion { + return &conn{ + localAddr: local, + remoteAddr: remote, + version: version, + } +} + +var _ net.Conn = &conn{} + +func (c *conn) Read([]byte) (int, error) { return 0, nil } +func (c *conn) Write([]byte) (int, error) { return 0, nil } +func (c *conn) Close() error { return nil } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) SetReadDeadline(time.Time) error { return nil } +func (c *conn) SetWriteDeadline(time.Time) error { return nil } +func (c *conn) SetDeadline(time.Time) error { return nil } +func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version } + +type cryptoSetup struct { + tlsConf *tls.Config + extraConf *qtls.ExtraConfig + conn *qtls.Conn + + version protocol.VersionNumber + + messageChan chan []byte + isReadingHandshakeMessage chan struct{} + readFirstHandshakeMessage bool + + ourParams *wire.TransportParameters + peerParams *wire.TransportParameters + paramsChan <-chan []byte + + runner handshakeRunner + + alertChan chan uint8 + // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns + handshakeDone chan struct{} + // is closed when Close() is called + closeChan chan struct{} + + zeroRTTParameters *wire.TransportParameters + clientHelloWritten bool + clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written + zeroRTTParametersChan chan<- *wire.TransportParameters + + rttStats *utils.RTTStats + + tracer logging.ConnectionTracer + logger utils.Logger + + perspective protocol.Perspective + + mutex sync.Mutex // protects all members below + + handshakeCompleteTime time.Time + + readEncLevel protocol.EncryptionLevel + writeEncLevel protocol.EncryptionLevel + + zeroRTTOpener LongHeaderOpener // only set for the server + zeroRTTSealer LongHeaderSealer // only set for the client + + initialStream io.Writer + initialOpener LongHeaderOpener + initialSealer LongHeaderSealer + + handshakeStream io.Writer + handshakeOpener LongHeaderOpener + handshakeSealer LongHeaderSealer + + aead *updatableAEAD + has1RTTSealer bool + has1RTTOpener bool +} + +var ( + _ qtls.RecordLayer = &cryptoSetup{} + _ CryptoSetup = &cryptoSetup{} +) + +// NewCryptoSetupClient creates a new crypto setup for the client +func NewCryptoSetupClient( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + localAddr net.Addr, + remoteAddr net.Addr, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, +) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { + cs, clientHelloWritten := newCryptoSetup( + initialStream, + handshakeStream, + connID, + tp, + runner, + tlsConf, + enable0RTT, + rttStats, + tracer, + logger, + protocol.PerspectiveClient, + version, + ) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + return cs, clientHelloWritten +} + +// NewCryptoSetupServer creates a new crypto setup for the server +func NewCryptoSetupServer( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + localAddr net.Addr, + remoteAddr net.Addr, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, +) CryptoSetup { + cs, _ := newCryptoSetup( + initialStream, + handshakeStream, + connID, + tp, + runner, + tlsConf, + enable0RTT, + rttStats, + tracer, + logger, + protocol.PerspectiveServer, + version, + ) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + return cs +} + +func newCryptoSetup( + initialStream io.Writer, + handshakeStream io.Writer, + connID protocol.ConnectionID, + tp *wire.TransportParameters, + runner handshakeRunner, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + perspective protocol.Perspective, + version protocol.VersionNumber, +) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) { + initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) + if tracer != nil { + tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) + tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) + } + extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version) + zeroRTTParametersChan := make(chan *wire.TransportParameters, 1) + cs := &cryptoSetup{ + tlsConf: tlsConf, + initialStream: initialStream, + initialSealer: initialSealer, + initialOpener: initialOpener, + handshakeStream: handshakeStream, + aead: newUpdatableAEAD(rttStats, tracer, logger, version), + readEncLevel: protocol.EncryptionInitial, + writeEncLevel: protocol.EncryptionInitial, + runner: runner, + ourParams: tp, + paramsChan: extHandler.TransportParameters(), + rttStats: rttStats, + tracer: tracer, + logger: logger, + perspective: perspective, + handshakeDone: make(chan struct{}), + alertChan: make(chan uint8), + clientHelloWrittenChan: make(chan struct{}), + zeroRTTParametersChan: zeroRTTParametersChan, + messageChan: make(chan []byte, 100), + isReadingHandshakeMessage: make(chan struct{}), + closeChan: make(chan struct{}), + version: version, + } + var maxEarlyData uint32 + if enable0RTT { + maxEarlyData = 0xffffffff + } + cs.extraConf = &qtls.ExtraConfig{ + GetExtensions: extHandler.GetExtensions, + ReceivedExtensions: extHandler.ReceivedExtensions, + AlternativeRecordLayer: cs, + EnforceNextProtoSelection: true, + MaxEarlyData: maxEarlyData, + Accept0RTT: cs.accept0RTT, + Rejected0RTT: cs.rejected0RTT, + Enable0RTT: enable0RTT, + GetAppDataForSessionState: cs.marshalDataForSessionState, + SetAppDataFromSessionState: cs.handleDataFromSessionState, + } + return cs, zeroRTTParametersChan +} + +func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { + initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) + h.initialSealer = initialSealer + h.initialOpener = initialOpener + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) + h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) + } +} + +func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error { + return h.aead.SetLargestAcked(pn) +} + +func (h *cryptoSetup) RunHandshake() { + // Handle errors that might occur when HandleData() is called. + handshakeComplete := make(chan struct{}) + handshakeErrChan := make(chan error, 1) + go func() { + defer close(h.handshakeDone) + if err := h.conn.Handshake(); err != nil { + handshakeErrChan <- err + return + } + close(handshakeComplete) + }() + + if h.perspective == protocol.PerspectiveClient { + select { + case err := <-handshakeErrChan: + h.onError(0, err.Error()) + return + case <-h.clientHelloWrittenChan: + } + } + + select { + case <-handshakeComplete: // return when the handshake is done + h.mutex.Lock() + h.handshakeCompleteTime = time.Now() + h.mutex.Unlock() + h.runner.OnHandshakeComplete() + case <-h.closeChan: + // wait until the Handshake() go routine has returned + <-h.handshakeDone + case alert := <-h.alertChan: + handshakeErr := <-handshakeErrChan + h.onError(alert, handshakeErr.Error()) + } +} + +func (h *cryptoSetup) onError(alert uint8, message string) { + var err error + if alert == 0 { + err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message} + } else { + err = qerr.NewCryptoError(alert, message) + } + h.runner.OnError(err) +} + +// Close closes the crypto setup. +// It aborts the handshake, if it is still running. +// It must only be called once. +func (h *cryptoSetup) Close() error { + close(h.closeChan) + // wait until qtls.Handshake() actually returned + <-h.handshakeDone + return nil +} + +// handleMessage handles a TLS handshake message. +// It is called by the crypto streams when a new message is available. +// It returns if it is done with messages on the same encryption level. +func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { + msgType := messageType(data[0]) + h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) + if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { + h.onError(alertUnexpectedMessage, err.Error()) + return false + } + h.messageChan <- data + if encLevel == protocol.Encryption1RTT { + h.handlePostHandshakeMessage() + return false + } +readLoop: + for { + select { + case data := <-h.paramsChan: + if data == nil { + h.onError(0x6d, "missing quic_transport_parameters extension") + } else { + h.handleTransportParameters(data) + } + case <-h.isReadingHandshakeMessage: + break readLoop + case <-h.handshakeDone: + break readLoop + case <-h.closeChan: + break readLoop + } + } + // We're done with the Initial encryption level after processing a ClientHello / ServerHello, + // but only if a handshake opener and sealer was created. + // Otherwise, a HelloRetryRequest was performed. + // We're done with the Handshake encryption level after processing the Finished message. + return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) || + msgType == typeFinished +} + +func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { + var expected protocol.EncryptionLevel + switch msgType { + case typeClientHello, + typeServerHello: + expected = protocol.EncryptionInitial + case typeEncryptedExtensions, + typeCertificate, + typeCertificateRequest, + typeCertificateVerify, + typeFinished: + expected = protocol.EncryptionHandshake + case typeNewSessionTicket: + expected = protocol.Encryption1RTT + default: + return fmt.Errorf("unexpected handshake message: %d", msgType) + } + if encLevel != expected { + return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) + } + return nil +} + +func (h *cryptoSetup) handleTransportParameters(data []byte) { + var tp wire.TransportParameters + if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { + h.runner.OnError(&qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + }) + } + h.peerParams = &tp + h.runner.OnReceivedParams(h.peerParams) +} + +// must be called after receiving the transport parameters +func (h *cryptoSetup) marshalDataForSessionState() []byte { + buf := &bytes.Buffer{} + quicvarint.Write(buf, clientSessionStateRevision) + quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) + h.peerParams.MarshalForSessionTicket(buf) + return buf.Bytes() +} + +func (h *cryptoSetup) handleDataFromSessionState(data []byte) { + tp, err := h.handleDataFromSessionStateImpl(data) + if err != nil { + h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) + return + } + h.zeroRTTParameters = tp +} + +func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { + r := bytes.NewReader(data) + ver, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if ver != clientSessionStateRevision { + return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) + } + rtt, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return nil, err + } + return &tp, nil +} + +// only valid for the server +func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { + var appData []byte + // Save transport parameters to the session ticket if we're allowing 0-RTT. + if h.extraConf.MaxEarlyData > 0 { + appData = (&sessionTicket{ + Parameters: h.ourParams, + RTT: h.rttStats.SmoothedRTT(), + }).Marshal() + } + return h.conn.GetSessionTicket(appData) +} + +// accept0RTT is called for the server when receiving the client's session ticket. +// It decides whether to accept 0-RTT. +func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { + var t sessionTicket + if err := t.Unmarshal(sessionTicketData); err != nil { + h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) + return false + } + valid := h.ourParams.ValidFor0RTT(t.Parameters) + if valid { + h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) + h.rttStats.SetInitialRTT(t.RTT) + } else { + h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") + } + return valid +} + +// rejected0RTT is called for the client when the server rejects 0-RTT. +func (h *cryptoSetup) rejected0RTT() { + h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") + + h.mutex.Lock() + had0RTTKeys := h.zeroRTTSealer != nil + h.zeroRTTSealer = nil + h.mutex.Unlock() + + if had0RTTKeys { + h.runner.DropKeys(protocol.Encryption0RTT) + } +} + +func (h *cryptoSetup) handlePostHandshakeMessage() { + // make sure the handshake has already completed + <-h.handshakeDone + + done := make(chan struct{}) + defer close(done) + + // h.alertChan is an unbuffered channel. + // If an error occurs during conn.HandlePostHandshakeMessage, + // it will be sent on this channel. + // Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock. + alertChan := make(chan uint8, 1) + go func() { + <-h.isReadingHandshakeMessage + select { + case alert := <-h.alertChan: + alertChan <- alert + case <-done: + } + }() + + if err := h.conn.HandlePostHandshakeMessage(); err != nil { + select { + case <-h.closeChan: + case alert := <-alertChan: + h.onError(alert, err.Error()) + } + } +} + +// ReadHandshakeMessage is called by TLS. +// It blocks until a new handshake message is available. +func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { + if !h.readFirstHandshakeMessage { + h.readFirstHandshakeMessage = true + } else { + select { + case h.isReadingHandshakeMessage <- struct{}{}: + case <-h.closeChan: + return nil, errors.New("error while handling the handshake message") + } + } + select { + case msg := <-h.messageChan: + return msg, nil + case <-h.closeChan: + return nil, errors.New("error while handling the handshake message") + } +} + +func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + h.mutex.Lock() + switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveClient { + panic("Received 0-RTT read key for the client") + } + h.zeroRTTOpener = newLongHeaderOpener( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + ) + h.mutex.Unlock() + h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite()) + } + return + case qtls.EncryptionHandshake: + h.readEncLevel = protocol.EncryptionHandshake + h.handshakeOpener = newHandshakeOpener( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + h.dropInitialKeys, + h.perspective, + ) + h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + case qtls.EncryptionApplication: + h.readEncLevel = protocol.Encryption1RTT + h.aead.SetReadKey(suite, trafficSecret) + h.has1RTTOpener = true + h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + default: + panic("unexpected read encryption level") + } + h.mutex.Unlock() + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite()) + } +} + +func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + h.mutex.Lock() + switch encLevel { + case qtls.Encryption0RTT: + if h.perspective == protocol.PerspectiveServer { + panic("Received 0-RTT write key for the server") + } + h.zeroRTTSealer = newLongHeaderSealer( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + ) + h.mutex.Unlock() + h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) + } + return + case qtls.EncryptionHandshake: + h.writeEncLevel = protocol.EncryptionHandshake + h.handshakeSealer = newHandshakeSealer( + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), + h.dropInitialKeys, + h.perspective, + ) + h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + case qtls.EncryptionApplication: + h.writeEncLevel = protocol.Encryption1RTT + h.aead.SetWriteKey(suite, trafficSecret) + h.has1RTTSealer = true + h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.zeroRTTSealer != nil { + h.zeroRTTSealer = nil + h.logger.Debugf("Dropping 0-RTT keys.") + if h.tracer != nil { + h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + } + } + default: + panic("unexpected write encryption level") + } + h.mutex.Unlock() + if h.tracer != nil { + h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective) + } +} + +// WriteRecord is called when TLS writes data +func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + //nolint:exhaustive // LS records can only be written for Initial and Handshake. + switch h.writeEncLevel { + case protocol.EncryptionInitial: + // assume that the first WriteRecord call contains the ClientHello + n, err := h.initialStream.Write(p) + if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { + h.clientHelloWritten = true + close(h.clientHelloWrittenChan) + if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { + h.logger.Debugf("Doing 0-RTT.") + h.zeroRTTParametersChan <- h.zeroRTTParameters + } else { + h.logger.Debugf("Not doing 0-RTT.") + h.zeroRTTParametersChan <- nil + } + } + return n, err + case protocol.EncryptionHandshake: + return h.handshakeStream.Write(p) + default: + panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel)) + } +} + +func (h *cryptoSetup) SendAlert(alert uint8) { + select { + case h.alertChan <- alert: + case <-h.closeChan: + // no need to send an alert when we've already closed + } +} + +// used a callback in the handshakeSealer and handshakeOpener +func (h *cryptoSetup) dropInitialKeys() { + h.mutex.Lock() + h.initialOpener = nil + h.initialSealer = nil + h.mutex.Unlock() + h.runner.DropKeys(protocol.EncryptionInitial) + h.logger.Debugf("Dropping Initial keys.") +} + +func (h *cryptoSetup) SetHandshakeConfirmed() { + h.aead.SetHandshakeConfirmed() + // drop Handshake keys + var dropped bool + h.mutex.Lock() + if h.handshakeOpener != nil { + h.handshakeOpener = nil + h.handshakeSealer = nil + dropped = true + } + h.mutex.Unlock() + if dropped { + h.runner.DropKeys(protocol.EncryptionHandshake) + h.logger.Debugf("Dropping Handshake keys.") + } +} + +func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.initialSealer == nil { + return nil, ErrKeysDropped + } + return h.initialSealer, nil +} + +func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTSealer == nil { + return nil, ErrKeysDropped + } + return h.zeroRTTSealer, nil +} + +func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.handshakeSealer == nil { + if h.initialSealer == nil { + return nil, ErrKeysDropped + } + return nil, ErrKeysNotYetAvailable + } + return h.handshakeSealer, nil +} + +func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if !h.has1RTTSealer { + return nil, ErrKeysNotYetAvailable + } + return h.aead, nil +} + +func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.initialOpener == nil { + return nil, ErrKeysDropped + } + return h.initialOpener, nil +} + +func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTOpener == nil { + if h.initialOpener != nil { + return nil, ErrKeysNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped + } + return h.zeroRTTOpener, nil +} + +func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.handshakeOpener == nil { + if h.initialOpener != nil { + return nil, ErrKeysNotYetAvailable + } + // if the initial opener is also not available, the keys were already dropped + return nil, ErrKeysDropped + } + return h.handshakeOpener, nil +} + +func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { + h.mutex.Lock() + defer h.mutex.Unlock() + + if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { + h.zeroRTTOpener = nil + h.logger.Debugf("Dropping 0-RTT keys.") + if h.tracer != nil { + h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + } + } + + if !h.has1RTTOpener { + return nil, ErrKeysNotYetAvailable + } + return h.aead, nil +} + +func (h *cryptoSetup) ConnectionState() ConnectionState { + return qtls.GetConnectionState(h.conn) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/header_protector.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/header_protector.go new file mode 100644 index 00000000..1f800c50 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/header_protector.go @@ -0,0 +1,136 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/tls" + "encoding/binary" + "fmt" + + "golang.org/x/crypto/chacha20" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" +) + +type headerProtector interface { + EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) + DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) +} + +func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { + if v == protocol.Version2 { + return "quicv2 hp" + } + return "quic hp" +} + +func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { + hkdfLabel := hkdfHeaderProtectionLabel(v) + switch suite.ID { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) + case tls.TLS_CHACHA20_POLY1305_SHA256: + return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) + default: + panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) + } +} + +type aesHeaderProtector struct { + mask []byte + block cipher.Block + isLongHeader bool +} + +var _ headerProtector = &aesHeaderProtector{} + +func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) + block, err := aes.NewCipher(hpKey) + if err != nil { + panic(fmt.Sprintf("error creating new AES cipher: %s", err)) + } + return &aesHeaderProtector{ + block: block, + mask: make([]byte, block.BlockSize()), + isLongHeader: isLongHeader, + } +} + +func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != len(p.mask) { + panic("invalid sample size") + } + p.block.Encrypt(p.mask, sample) + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} + +type chachaHeaderProtector struct { + mask [5]byte + + key [32]byte + isLongHeader bool +} + +var _ headerProtector = &chachaHeaderProtector{} + +func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) + + p := &chachaHeaderProtector{ + isLongHeader: isLongHeader, + } + copy(p.key[:], hpKey) + return p +} + +func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + p.apply(sample, firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) { + if len(sample) != 16 { + panic("invalid sample size") + } + for i := 0; i < 5; i++ { + p.mask[i] = 0 + } + cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:]) + if err != nil { + panic(err) + } + cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4])) + cipher.XORKeyStream(p.mask[:], p.mask[:]) + p.applyMask(firstByte, hdrBytes) +} + +func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) { + if p.isLongHeader { + *firstByte ^= p.mask[0] & 0xf + } else { + *firstByte ^= p.mask[0] & 0x1f + } + for i := range hdrBytes { + hdrBytes[i] ^= p.mask[i+1] + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/hkdf.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/hkdf.go new file mode 100644 index 00000000..c4fd86c5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/hkdf.go @@ -0,0 +1,29 @@ +package handshake + +import ( + "crypto" + "encoding/binary" + + "golang.org/x/crypto/hkdf" +) + +// hkdfExpandLabel HKDF expands a label. +// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the +// hkdfExpandLabel in the standard library. +func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { + b := make([]byte, 3, 3+6+len(label)+1+len(context)) + binary.BigEndian.PutUint16(b, uint16(length)) + b[2] = uint8(6 + len(label)) + b = append(b, []byte("tls13 ")...) + b = append(b, []byte(label)...) + b = b[:3+6+len(label)+1] + b[3+6+len(label)] = uint8(len(context)) + b = append(b, context...) + + out := make([]byte, length) + n, err := hkdf.Expand(hash.New, secret, b).Read(out) + if err != nil || n != length { + panic("quic: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/initial_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/initial_aead.go new file mode 100644 index 00000000..00ed243c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/initial_aead.go @@ -0,0 +1,81 @@ +package handshake + +import ( + "crypto" + "crypto/tls" + + "golang.org/x/crypto/hkdf" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" +) + +var ( + quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} + quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + quicSaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3} +) + +const ( + hkdfLabelKeyV1 = "quic key" + hkdfLabelKeyV2 = "quicv2 key" + hkdfLabelIVV1 = "quic iv" + hkdfLabelIVV2 = "quicv2 iv" +) + +func getSalt(v protocol.VersionNumber) []byte { + if v == protocol.Version2 { + return quicSaltV2 + } + if v == protocol.Version1 { + return quicSaltV1 + } + return quicSaltOld +} + +var initialSuite = &qtls.CipherSuiteTLS13{ + ID: tls.TLS_AES_128_GCM_SHA256, + KeyLen: 16, + AEAD: qtls.AEADAESGCMTLS13, + Hash: crypto.SHA256, +} + +// NewInitialAEAD creates a new AEAD for Initial encryption / decryption. +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { + clientSecret, serverSecret := computeSecrets(connID, v) + var mySecret, otherSecret []byte + if pers == protocol.PerspectiveClient { + mySecret = clientSecret + otherSecret = serverSecret + } else { + mySecret = serverSecret + otherSecret = clientSecret + } + myKey, myIV := computeInitialKeyAndIV(mySecret, v) + otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) + + encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) + decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) + + return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), + newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) +} + +func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { + initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) + clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) + serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) + return +} + +func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) + iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) + return +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go new file mode 100644 index 00000000..112f6c25 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go @@ -0,0 +1,102 @@ +package handshake + +import ( + "errors" + "io" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +var ( + // ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level, + // but the corresponding opener has not yet been initialized + // This can happen when packets arrive out of order. + ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available") + // ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level, + // but the corresponding keys have already been dropped. + ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped") + // ErrDecryptionFailed is returned when the AEAD fails to open the packet. + ErrDecryptionFailed = errors.New("decryption failed") +) + +// ConnectionState contains information about the state of the connection. +type ConnectionState = qtls.ConnectionState + +type headerDecryptor interface { + DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) +} + +// LongHeaderOpener opens a long header packet +type LongHeaderOpener interface { + headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber + Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error) +} + +// ShortHeaderOpener opens a short header packet +type ShortHeaderOpener interface { + headerDecryptor + DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber + Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error) +} + +// LongHeaderSealer seals a long header packet +type LongHeaderSealer interface { + Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) + Overhead() int +} + +// ShortHeaderSealer seals a short header packet +type ShortHeaderSealer interface { + LongHeaderSealer + KeyPhase() protocol.KeyPhaseBit +} + +// A tlsExtensionHandler sends and received the QUIC TLS extension. +type tlsExtensionHandler interface { + GetExtensions(msgType uint8) []qtls.Extension + ReceivedExtensions(msgType uint8, exts []qtls.Extension) + TransportParameters() <-chan []byte +} + +type handshakeRunner interface { + OnReceivedParams(*wire.TransportParameters) + OnHandshakeComplete() + OnError(error) + DropKeys(protocol.EncryptionLevel) +} + +// CryptoSetup handles the handshake and protecting / unprotecting packets +type CryptoSetup interface { + RunHandshake() + io.Closer + ChangeConnectionID(protocol.ConnectionID) + GetSessionTicket() ([]byte, error) + + HandleMessage([]byte, protocol.EncryptionLevel) bool + SetLargest1RTTAcked(protocol.PacketNumber) error + SetHandshakeConfirmed() + ConnectionState() ConnectionState + + GetInitialOpener() (LongHeaderOpener, error) + GetHandshakeOpener() (LongHeaderOpener, error) + Get0RTTOpener() (LongHeaderOpener, error) + Get1RTTOpener() (ShortHeaderOpener, error) + + GetInitialSealer() (LongHeaderSealer, error) + GetHandshakeSealer() (LongHeaderSealer, error) + Get0RTTSealer() (LongHeaderSealer, error) + Get1RTTSealer() (ShortHeaderSealer, error) +} + +// ConnWithVersion is the connection used in the ClientHelloInfo. +// It can be used to determine the QUIC version in use. +type ConnWithVersion interface { + net.Conn + GetQUICVersion() protocol.VersionNumber +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go new file mode 100644 index 00000000..c7a8d13e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go @@ -0,0 +1,3 @@ +package handshake + +//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner" diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go new file mode 100644 index 00000000..b7cb20c1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go @@ -0,0 +1,62 @@ +package handshake + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +var ( + oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34 + retryAEAD cipher.AEAD // used for QUIC draft-34 +) + +func init() { + oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}) + retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) +} + +func initAEAD(key [16]byte) cipher.AEAD { + aes, err := aes.NewCipher(key[:]) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + return aead +} + +var ( + retryBuf bytes.Buffer + retryMutex sync.Mutex + oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c} + retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} +) + +// GetRetryIntegrityTag calculates the integrity tag on a Retry packet +func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { + retryMutex.Lock() + retryBuf.WriteByte(uint8(origDestConnID.Len())) + retryBuf.Write(origDestConnID.Bytes()) + retryBuf.Write(retry) + + var tag [16]byte + var sealed []byte + if version != protocol.Version1 { + sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes()) + } else { + sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes()) + } + if len(sealed) != 16 { + panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) + } + retryBuf.Reset() + retryMutex.Unlock() + return &tag +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/session_ticket.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/session_ticket.go new file mode 100644 index 00000000..75cc04f9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/session_ticket.go @@ -0,0 +1,48 @@ +package handshake + +import ( + "bytes" + "errors" + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +const sessionTicketRevision = 2 + +type sessionTicket struct { + Parameters *wire.TransportParameters + RTT time.Duration // to be encoded in mus +} + +func (t *sessionTicket) Marshal() []byte { + b := &bytes.Buffer{} + quicvarint.Write(b, sessionTicketRevision) + quicvarint.Write(b, uint64(t.RTT.Microseconds())) + t.Parameters.MarshalForSessionTicket(b) + return b.Bytes() +} + +func (t *sessionTicket) Unmarshal(b []byte) error { + r := bytes.NewReader(b) + rev, err := quicvarint.Read(r) + if err != nil { + return errors.New("failed to read session ticket revision") + } + if rev != sessionTicketRevision { + return fmt.Errorf("unknown session ticket revision: %d", rev) + } + rtt, err := quicvarint.Read(r) + if err != nil { + return errors.New("failed to read RTT") + } + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + } + t.Parameters = &tp + t.RTT = time.Duration(rtt) * time.Microsecond + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler.go new file mode 100644 index 00000000..3a679034 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler.go @@ -0,0 +1,68 @@ +package handshake + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qtls" +) + +const ( + quicTLSExtensionTypeOldDrafts = 0xffa5 + quicTLSExtensionType = 0x39 +) + +type extensionHandler struct { + ourParams []byte + paramsChan chan []byte + + extensionType uint16 + + perspective protocol.Perspective +} + +var _ tlsExtensionHandler = &extensionHandler{} + +// newExtensionHandler creates a new extension handler +func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler { + et := uint16(quicTLSExtensionType) + if v != protocol.Version1 { + et = quicTLSExtensionTypeOldDrafts + } + return &extensionHandler{ + ourParams: params, + paramsChan: make(chan []byte), + perspective: pers, + extensionType: et, + } +} + +func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension { + if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) || + (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) { + return nil + } + return []qtls.Extension{{ + Type: h.extensionType, + Data: h.ourParams, + }} +} + +func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) { + if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) || + (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) { + return + } + + var data []byte + for _, ext := range exts { + if ext.Type == h.extensionType { + data = ext.Data + break + } + } + + h.paramsChan <- data +} + +func (h *extensionHandler) TransportParameters() <-chan []byte { + return h.paramsChan +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_generator.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_generator.go new file mode 100644 index 00000000..2df5fcd8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_generator.go @@ -0,0 +1,134 @@ +package handshake + +import ( + "encoding/asn1" + "fmt" + "io" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +const ( + tokenPrefixIP byte = iota + tokenPrefixString +) + +// A Token is derived from the client address and can be used to verify the ownership of this address. +type Token struct { + IsRetryToken bool + RemoteAddr string + SentTime time.Time + // only set for retry tokens + OriginalDestConnectionID protocol.ConnectionID + RetrySrcConnectionID protocol.ConnectionID +} + +// token is the struct that is used for ASN1 serialization and deserialization +type token struct { + IsRetryToken bool + RemoteAddr []byte + Timestamp int64 + OriginalDestConnectionID []byte + RetrySrcConnectionID []byte +} + +// A TokenGenerator generates tokens +type TokenGenerator struct { + tokenProtector tokenProtector +} + +// NewTokenGenerator initializes a new TookenGenerator +func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { + tokenProtector, err := newTokenProtector(rand) + if err != nil { + return nil, err + } + return &TokenGenerator{ + tokenProtector: tokenProtector, + }, nil +} + +// NewRetryToken generates a new token for a Retry for a given source address +func (g *TokenGenerator) NewRetryToken( + raddr net.Addr, + origDestConnID protocol.ConnectionID, + retrySrcConnID protocol.ConnectionID, +) ([]byte, error) { + data, err := asn1.Marshal(token{ + IsRetryToken: true, + RemoteAddr: encodeRemoteAddr(raddr), + OriginalDestConnectionID: origDestConnID, + RetrySrcConnectionID: retrySrcConnID, + Timestamp: time.Now().UnixNano(), + }) + if err != nil { + return nil, err + } + return g.tokenProtector.NewToken(data) +} + +// NewToken generates a new token to be sent in a NEW_TOKEN frame +func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) { + data, err := asn1.Marshal(token{ + RemoteAddr: encodeRemoteAddr(raddr), + Timestamp: time.Now().UnixNano(), + }) + if err != nil { + return nil, err + } + return g.tokenProtector.NewToken(data) +} + +// DecodeToken decodes a token +func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { + // if the client didn't send any token, DecodeToken will be called with a nil-slice + if len(encrypted) == 0 { + return nil, nil + } + + data, err := g.tokenProtector.DecodeToken(encrypted) + if err != nil { + return nil, err + } + t := &token{} + rest, err := asn1.Unmarshal(data, t) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) + } + token := &Token{ + IsRetryToken: t.IsRetryToken, + RemoteAddr: decodeRemoteAddr(t.RemoteAddr), + SentTime: time.Unix(0, t.Timestamp), + } + if t.IsRetryToken { + token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) + token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) + } + return token, nil +} + +// encodeRemoteAddr encodes a remote address such that it can be saved in the token +func encodeRemoteAddr(remoteAddr net.Addr) []byte { + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + return append([]byte{tokenPrefixIP}, udpAddr.IP...) + } + return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) +} + +// decodeRemoteAddr decodes the remote address saved in the token +func decodeRemoteAddr(data []byte) string { + // data will never be empty for a token that we generated. + // Check it to be on the safe side + if len(data) == 0 { + return "" + } + if data[0] == tokenPrefixIP { + return net.IP(data[1:]).String() + } + return string(data[1:]) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_protector.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_protector.go new file mode 100644 index 00000000..650f230b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_protector.go @@ -0,0 +1,89 @@ +package handshake + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +// TokenProtector is used to create and verify a token +type tokenProtector interface { + // NewToken creates a new token + NewToken([]byte) ([]byte, error) + // DecodeToken decodes a token + DecodeToken([]byte) ([]byte, error) +} + +const ( + tokenSecretSize = 32 + tokenNonceSize = 32 +) + +// tokenProtector is used to create and verify a token +type tokenProtectorImpl struct { + rand io.Reader + secret []byte +} + +// newTokenProtector creates a source for source address tokens +func newTokenProtector(rand io.Reader) (tokenProtector, error) { + secret := make([]byte, tokenSecretSize) + if _, err := rand.Read(secret); err != nil { + return nil, err + } + return &tokenProtectorImpl{ + rand: rand, + secret: secret, + }, nil +} + +// NewToken encodes data into a new token. +func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { + nonce := make([]byte, tokenNonceSize) + if _, err := s.rand.Read(nonce); err != nil { + return nil, err + } + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil +} + +// DecodeToken decodes a token. +func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { + if len(p) < tokenNonceSize { + return nil, fmt.Errorf("token too short: %d", len(p)) + } + nonce := p[:tokenNonceSize] + aead, aeadNonce, err := s.createAEAD(nonce) + if err != nil { + return nil, err + } + return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil) +} + +func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { + h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) + key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 + if _, err := io.ReadFull(h, key); err != nil { + return nil, nil, err + } + aeadNonce := make([]byte, 12) + if _, err := io.ReadFull(h, aeadNonce); err != nil { + return nil, nil, err + } + c, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, nil, err + } + return aead, aeadNonce, nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/updatable_aead.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/updatable_aead.go new file mode 100644 index 00000000..1532e7b5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/updatable_aead.go @@ -0,0 +1,323 @@ +package handshake + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "encoding/binary" + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" +) + +// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. +// It's a package-level variable to allow modifying it for testing purposes. +var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval + +type updatableAEAD struct { + suite *qtls.CipherSuiteTLS13 + + keyPhase protocol.KeyPhase + largestAcked protocol.PacketNumber + firstPacketNumber protocol.PacketNumber + handshakeConfirmed bool + + keyUpdateInterval uint64 + invalidPacketLimit uint64 + invalidPacketCount uint64 + + // Time when the keys should be dropped. Keys are dropped on the next call to Open(). + prevRcvAEADExpiry time.Time + prevRcvAEAD cipher.AEAD + + firstRcvdWithCurrentKey protocol.PacketNumber + firstSentWithCurrentKey protocol.PacketNumber + highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) + numRcvdWithCurrentKey uint64 + numSentWithCurrentKey uint64 + rcvAEAD cipher.AEAD + sendAEAD cipher.AEAD + // caches cipher.AEAD.Overhead(). This speeds up calls to Overhead(). + aeadOverhead int + + nextRcvAEAD cipher.AEAD + nextSendAEAD cipher.AEAD + nextRcvTrafficSecret []byte + nextSendTrafficSecret []byte + + headerDecrypter headerProtector + headerEncrypter headerProtector + + rttStats *utils.RTTStats + + tracer logging.ConnectionTracer + logger utils.Logger + version protocol.VersionNumber + + // use a single slice to avoid allocations + nonceBuf []byte +} + +var ( + _ ShortHeaderOpener = &updatableAEAD{} + _ ShortHeaderSealer = &updatableAEAD{} +) + +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { + return &updatableAEAD{ + firstPacketNumber: protocol.InvalidPacketNumber, + largestAcked: protocol.InvalidPacketNumber, + firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, + firstSentWithCurrentKey: protocol.InvalidPacketNumber, + keyUpdateInterval: KeyUpdateInterval, + rttStats: rttStats, + tracer: tracer, + logger: logger, + version: version, + } +} + +func (a *updatableAEAD) rollKeys() { + if a.prevRcvAEAD != nil { + a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) + if a.tracer != nil { + a.tracer.DroppedKey(a.keyPhase - 1) + } + a.prevRcvAEADExpiry = time.Time{} + } + + a.keyPhase++ + a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber + a.firstSentWithCurrentKey = protocol.InvalidPacketNumber + a.numRcvdWithCurrentKey = 0 + a.numSentWithCurrentKey = 0 + a.prevRcvAEAD = a.rcvAEAD + a.rcvAEAD = a.nextRcvAEAD + a.sendAEAD = a.nextSendAEAD + + a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) + a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) + a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) + a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) +} + +func (a *updatableAEAD) startKeyDropTimer(now time.Time) { + d := 3 * a.rttStats.PTO(true) + a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) + a.prevRcvAEADExpiry = now.Add(d) +} + +func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { + return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) +} + +// For the client, this function is called before SetWriteKey. +// For the server, this function is called after SetWriteKey. +func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) + if a.suite == nil { + a.setAEADParameters(a.rcvAEAD, suite) + } + + a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) + a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) +} + +// For the client, this function is called after SetReadKey. +// For the server, this function is called before SetWriteKey. +func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { + a.sendAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) + if a.suite == nil { + a.setAEADParameters(a.sendAEAD, suite) + } + + a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) + a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) +} + +func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { + a.nonceBuf = make([]byte, aead.NonceSize()) + a.aeadOverhead = aead.Overhead() + a.suite = suite + switch suite.ID { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + a.invalidPacketLimit = protocol.InvalidPacketLimitAES + case tls.TLS_CHACHA20_POLY1305_SHA256: + a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha + default: + panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) + } +} + +func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { + return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) +} + +func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + dec, err := a.open(dst, src, rcvTime, pn, kp, ad) + if err == ErrDecryptionFailed { + a.invalidPacketCount++ + if a.invalidPacketCount >= a.invalidPacketLimit { + return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} + } + } + if err == nil { + a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn) + } + return dec, err +} + +func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { + if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { + a.prevRcvAEAD = nil + a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) + a.prevRcvAEADExpiry = time.Time{} + if a.tracer != nil { + a.tracer.DroppedKey(a.keyPhase - 1) + } + } + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + if kp != a.keyPhase.Bit() { + if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { + if a.prevRcvAEAD == nil { + return nil, ErrKeysDropped + } + // we updated the key, but the peer hasn't updated yet + dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + err = ErrDecryptionFailed + } + return dec, err + } + // try opening the packet with the next key phase + dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + return nil, ErrDecryptionFailed + } + // Opening succeeded. Check if the peer was allowed to update. + if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + return nil, &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + } + } + a.rollKeys() + a.logger.Debugf("Peer updated keys to %d", a.keyPhase) + // The peer initiated this key update. It's safe to drop the keys for the previous generation now. + // Start a timer to drop the previous key generation. + a.startKeyDropTimer(rcvTime) + if a.tracer != nil { + a.tracer.UpdatedKey(a.keyPhase, true) + } + a.firstRcvdWithCurrentKey = pn + return dec, err + } + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) + if err != nil { + return dec, ErrDecryptionFailed + } + a.numRcvdWithCurrentKey++ + if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { + // We initiated the key updated, and now we received the first packet protected with the new key phase. + // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. + if a.keyPhase > 0 { + a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) + a.startKeyDropTimer(rcvTime) + } + a.firstRcvdWithCurrentKey = pn + } + return dec, err +} + +func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { + if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { + a.firstSentWithCurrentKey = pn + } + if a.firstPacketNumber == protocol.InvalidPacketNumber { + a.firstPacketNumber = pn + } + a.numSentWithCurrentKey++ + binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) + // The AEAD we're using here will be the qtls.aeadAESGCM13. + // It uses the nonce provided here and XOR it with the IV. + return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) +} + +func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { + if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { + return &qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), + } + } + a.largestAcked = pn + return nil +} + +func (a *updatableAEAD) SetHandshakeConfirmed() { + a.handshakeConfirmed = true +} + +func (a *updatableAEAD) updateAllowed() bool { + if !a.handshakeConfirmed { + return false + } + // the first key update is allowed as soon as the handshake is confirmed + return a.keyPhase == 0 || + // subsequent key updates as soon as a packet sent with that key phase has been acknowledged + (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && + a.largestAcked != protocol.InvalidPacketNumber && + a.largestAcked >= a.firstSentWithCurrentKey) +} + +func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { + if !a.updateAllowed() { + return false + } + if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) + return true + } + if a.numSentWithCurrentKey >= a.keyUpdateInterval { + a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) + return true + } + return false +} + +func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { + if a.shouldInitiateKeyUpdate() { + a.rollKeys() + a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) + if a.tracer != nil { + a.tracer.UpdatedKey(a.keyPhase, false) + } + } + return a.keyPhase.Bit() +} + +func (a *updatableAEAD) Overhead() int { + return a.aeadOverhead +} + +func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) +} + +func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { + a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) +} + +func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { + return a.firstPacketNumber +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/logutils/frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/logutils/frame.go new file mode 100644 index 00000000..6e0fd311 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/logutils/frame.go @@ -0,0 +1,33 @@ +package logutils + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" +) + +// ConvertFrame converts a wire.Frame into a logging.Frame. +// This makes it possible for external packages to access the frames. +// Furthermore, it removes the data slices from CRYPTO and STREAM frames. +func ConvertFrame(frame wire.Frame) logging.Frame { + switch f := frame.(type) { + case *wire.CryptoFrame: + return &logging.CryptoFrame{ + Offset: f.Offset, + Length: protocol.ByteCount(len(f.Data)), + } + case *wire.StreamFrame: + return &logging.StreamFrame{ + StreamID: f.StreamID, + Offset: f.Offset, + Length: f.DataLen(), + Fin: f.Fin, + } + case *wire.DatagramFrame: + return &logging.DatagramFrame{ + Length: logging.ByteCount(len(f.Data)), + } + default: + return logging.Frame(frame) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go new file mode 100644 index 00000000..3aec2cd3 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go @@ -0,0 +1,69 @@ +package protocol + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" +) + +// A ConnectionID in QUIC +type ConnectionID []byte + +const maxConnectionIDLen = 20 + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID(len int) (ConnectionID, error) { + b := make([]byte, len) + if _, err := rand.Read(b); err != nil { + return nil, err + } + return ConnectionID(b), nil +} + +// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. +// It uses a length randomly chosen between 8 and 20 bytes. +func GenerateConnectionIDForInitial() (ConnectionID, error) { + r := make([]byte, 1) + if _, err := rand.Read(r); err != nil { + return nil, err + } + len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(len) +} + +// ReadConnectionID reads a connection ID of length len from the given io.Reader. +// It returns io.EOF if there are not enough bytes to read. +func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { + if len == 0 { + return nil, nil + } + c := make(ConnectionID, len) + _, err := io.ReadFull(r, c) + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return c, err +} + +// Equal says if two connection IDs are equal +func (c ConnectionID) Equal(other ConnectionID) bool { + return bytes.Equal(c, other) +} + +// Len returns the length of the connection ID in bytes +func (c ConnectionID) Len() int { + return len(c) +} + +// Bytes returns the byte representation +func (c ConnectionID) Bytes() []byte { + return []byte(c) +} + +func (c ConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%x", c.Bytes()) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go new file mode 100644 index 00000000..32d38ab1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go @@ -0,0 +1,30 @@ +package protocol + +// EncryptionLevel is the encryption level +// Default value is Unencrypted +type EncryptionLevel uint8 + +const ( + // EncryptionInitial is the Initial encryption level + EncryptionInitial EncryptionLevel = 1 + iota + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT + // Encryption1RTT is the 1-RTT encryption level + Encryption1RTT +) + +func (e EncryptionLevel) String() string { + switch e { + case EncryptionInitial: + return "Initial" + case EncryptionHandshake: + return "Handshake" + case Encryption0RTT: + return "0-RTT" + case Encryption1RTT: + return "1-RTT" + } + return "unknown" +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/key_phase.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/key_phase.go new file mode 100644 index 00000000..edd740cf --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/key_phase.go @@ -0,0 +1,36 @@ +package protocol + +// KeyPhase is the key phase +type KeyPhase uint64 + +// Bit determines the key phase bit +func (p KeyPhase) Bit() KeyPhaseBit { + if p%2 == 0 { + return KeyPhaseZero + } + return KeyPhaseOne +} + +// KeyPhaseBit is the key phase bit +type KeyPhaseBit uint8 + +const ( + // KeyPhaseUndefined is an undefined key phase + KeyPhaseUndefined KeyPhaseBit = iota + // KeyPhaseZero is key phase 0 + KeyPhaseZero + // KeyPhaseOne is key phase 1 + KeyPhaseOne +) + +func (p KeyPhaseBit) String() string { + //nolint:exhaustive + switch p { + case KeyPhaseZero: + return "0" + case KeyPhaseOne: + return "1" + default: + return "undefined" + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go new file mode 100644 index 00000000..bd340161 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go @@ -0,0 +1,79 @@ +package protocol + +// A PacketNumber in QUIC +type PacketNumber int64 + +// InvalidPacketNumber is a packet number that is never sent. +// In QUIC, 0 is a valid packet number. +const InvalidPacketNumber PacketNumber = -1 + +// PacketNumberLen is the length of the packet number in bytes +type PacketNumberLen uint8 + +const ( + // PacketNumberLen1 is a packet number length of 1 byte + PacketNumberLen1 PacketNumberLen = 1 + // PacketNumberLen2 is a packet number length of 2 bytes + PacketNumberLen2 PacketNumberLen = 2 + // PacketNumberLen3 is a packet number length of 3 bytes + PacketNumberLen3 PacketNumberLen = 3 + // PacketNumberLen4 is a packet number length of 4 bytes + PacketNumberLen4 PacketNumberLen = 4 +) + +// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number +func DecodePacketNumber( + packetNumberLength PacketNumberLen, + lastPacketNumber PacketNumber, + wirePacketNumber PacketNumber, +) PacketNumber { + var epochDelta PacketNumber + switch packetNumberLength { + case PacketNumberLen1: + epochDelta = PacketNumber(1) << 8 + case PacketNumberLen2: + epochDelta = PacketNumber(1) << 16 + case PacketNumberLen3: + epochDelta = PacketNumber(1) << 24 + case PacketNumberLen4: + epochDelta = PacketNumber(1) << 32 + } + epoch := lastPacketNumber & ^(epochDelta - 1) + var prevEpochBegin PacketNumber + if epoch > epochDelta { + prevEpochBegin = epoch - epochDelta + } + nextEpochBegin := epoch + epochDelta + return closestTo( + lastPacketNumber+1, + epoch+wirePacketNumber, + closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber), + ) +} + +func closestTo(target, a, b PacketNumber) PacketNumber { + if delta(target, a) < delta(target, b) { + return a + } + return b +} + +func delta(a, b PacketNumber) PacketNumber { + if a < b { + return b - a + } + return a - b +} + +// GetPacketNumberLengthForHeader gets the length of the packet number for the public header +// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances +func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen { + diff := uint64(packetNumber - leastUnacked) + if diff < (1 << (16 - 1)) { + return PacketNumberLen2 + } + if diff < (1 << (24 - 1)) { + return PacketNumberLen3 + } + return PacketNumberLen4 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/params.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/params.go new file mode 100644 index 00000000..83137113 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/params.go @@ -0,0 +1,193 @@ +package protocol + +import "time" + +// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. +const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB + +// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. +const InitialPacketSizeIPv4 = 1252 + +// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets. +const InitialPacketSizeIPv6 = 1232 + +// MaxCongestionWindowPackets is the maximum congestion window in packet. +const MaxCongestionWindowPackets = 10000 + +// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. +const MaxUndecryptablePackets = 32 + +// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window +// This is the value that Chromium is using +const ConnectionFlowControlMultiplier = 1.5 + +// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data +const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb + +// DefaultInitialMaxData is the connection-level flow control window for receiving data +const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData + +// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data +const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB + +// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data +const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB + +// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client +const WindowUpdateThreshold = 0.25 + +// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open +const DefaultMaxIncomingStreams = 100 + +// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open +const DefaultMaxIncomingUniStreams = 100 + +// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. +const MaxServerUnprocessedPackets = 1024 + +// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed. +const MaxConnUnprocessedPackets = 256 + +// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. +// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. +const SkipPacketInitialPeriod PacketNumber = 256 + +// SkipPacketMaxPeriod is the maximum period length used for packet number skipping. +const SkipPacketMaxPeriod PacketNumber = 128 * 1024 + +// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting. +// If the queue is full, new connection attempts will be rejected. +const MaxAcceptQueueSize = 32 + +// TokenValidity is the duration that a (non-retry) token is considered valid +const TokenValidity = 24 * time.Hour + +// RetryTokenValidity is the duration that a retry token is considered valid +const RetryTokenValidity = 10 * time.Second + +// MaxOutstandingSentPackets is maximum number of packets saved for retransmission. +// When reached, it imposes a soft limit on sending new packets: +// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. +const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets + +// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission. +// When reached, no more packets will be sent. +// This value *must* be larger than MaxOutstandingSentPackets. +const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4 + +// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK, +// but no ack-eliciting frames, that we send in a row +const MaxNonAckElicitingAcks = 19 + +// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames +// prevents DoS attacks against the streamFrameSorter +const MaxStreamFrameSorterGaps = 1000 + +// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame +// that we use the buffer for. This protects against a DoS where an attacker would send us +// very small STREAM frames to consume a lot of memory. +const MinStreamFrameBufferSize = 128 + +// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack. +// If a packet has less than this number of bytes, we won't coalesce any more packets onto it. +const MinCoalescedPacketSize = 128 + +// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams. +// This limits the size of the ClientHello and Certificates that can be received. +const MaxCryptoStreamOffset = 16 * (1 << 10) + +// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout +const MinRemoteIdleTimeout = 5 * time.Second + +// DefaultIdleTimeout is the default idle timeout +const DefaultIdleTimeout = 30 * time.Second + +// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. +const DefaultHandshakeIdleTimeout = 5 * time.Second + +// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. +const DefaultHandshakeTimeout = 10 * time.Second + +// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. +// It should be shorter than the time that NATs clear their mapping. +const MaxKeepAliveInterval = 20 * time.Second + +// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE. +// after this time all information about the old connection will be deleted +const RetiredConnectionIDDeleteTimeout = 5 * time.Second + +// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame. +// This avoids splitting up STREAM frames into small pieces, which has 2 advantages: +// 1. it reduces the framing overhead +// 2. it reduces the head-of-line blocking, when a packet is lost +const MinStreamFrameSize ByteCount = 128 + +// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames +// we send after the handshake completes. +const MaxPostHandshakeCryptoFrameSize = 1000 + +// MaxAckFrameSize is the maximum size for an ACK frame that we write +// Due to the varint encoding, ACK frames can grow (almost) indefinitely large. +// The MaxAckFrameSize should be large enough to encode many ACK range, +// but must ensure that a maximum size ACK frame fits into one packet. +const MaxAckFrameSize ByteCount = 1000 + +// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). +// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. +const MaxDatagramFrameSize ByteCount = 1220 + +// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) +const DatagramRcvQueueLen = 128 + +// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. +// It also serves as a limit for the packet history. +// If at any point we keep track of more ranges, old ranges are discarded. +const MaxNumAckRanges = 32 + +// MinPacingDelay is the minimum duration that is used for packet pacing +// If the packet packing frequency is higher, multiple packets might be sent at once. +// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth. +const MinPacingDelay = time.Millisecond + +// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections +// if no other value is configured. +const DefaultConnectionIDLength = 4 + +// MaxActiveConnectionIDs is the number of connection IDs that we're storing. +const MaxActiveConnectionIDs = 4 + +// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time. +const MaxIssuedConnectionIDs = 6 + +// PacketsPerConnectionID is the number of packets we send using one connection ID. +// If the peer provices us with enough new connection IDs, we switch to a new connection ID. +const PacketsPerConnectionID = 10000 + +// AckDelayExponent is the ack delay exponent used when sending ACKs. +const AckDelayExponent = 3 + +// Estimated timer granularity. +// The loss detection timer will not be set to a value smaller than granularity. +const TimerGranularity = time.Millisecond + +// MaxAckDelay is the maximum time by which we delay sending ACKs. +const MaxAckDelay = 25 * time.Millisecond + +// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity. +// This is the value that should be advertised to the peer. +const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity + +// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. +const KeyUpdateInterval = 100 * 1000 + +// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received. +const Max0RTTQueueingDuration = 100 * time.Millisecond + +// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for. +const Max0RTTQueues = 32 + +// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. +// When a new connection is created, all buffered packets are passed to the connection immediately. +// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets. +// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets. +const Max0RTTQueueLen = 31 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go new file mode 100644 index 00000000..43358fec --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go @@ -0,0 +1,26 @@ +package protocol + +// Perspective determines if we're acting as a server or a client +type Perspective int + +// the perspectives +const ( + PerspectiveServer Perspective = 1 + PerspectiveClient Perspective = 2 +) + +// Opposite returns the perspective of the peer +func (p Perspective) Opposite() Perspective { + return 3 - p +} + +func (p Perspective) String() string { + switch p { + case PerspectiveServer: + return "Server" + case PerspectiveClient: + return "Client" + default: + return "invalid perspective" + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go new file mode 100644 index 00000000..8241e274 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go @@ -0,0 +1,97 @@ +package protocol + +import ( + "fmt" + "time" +) + +// The PacketType is the Long Header Type +type PacketType uint8 + +const ( + // PacketTypeInitial is the packet type of an Initial packet + PacketTypeInitial PacketType = 1 + iota + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry + // PacketTypeHandshake is the packet type of a Handshake packet + PacketTypeHandshake + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT +) + +func (t PacketType) String() string { + switch t { + case PacketTypeInitial: + return "Initial" + case PacketTypeRetry: + return "Retry" + case PacketTypeHandshake: + return "Handshake" + case PacketType0RTT: + return "0-RTT Protected" + default: + return fmt.Sprintf("unknown packet type: %d", t) + } +} + +type ECN uint8 + +const ( + ECNNon ECN = iota // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 +) + +// A ByteCount in QUIC +type ByteCount int64 + +// MaxByteCount is the maximum value of a ByteCount +const MaxByteCount = ByteCount(1<<62 - 1) + +// InvalidByteCount is an invalid byte count +const InvalidByteCount ByteCount = -1 + +// A StatelessResetToken is a stateless reset token. +type StatelessResetToken [16]byte + +// MaxPacketBufferSize maximum packet size of any QUIC packet, based on +// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, +// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. +// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. +const MaxPacketBufferSize ByteCount = 1452 + +// MinInitialPacketSize is the minimum size an Initial packet is required to have. +const MinInitialPacketSize = 1200 + +// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version +// needs to have in order to trigger a Version Negotiation packet. +const MinUnknownVersionPacketSize = MinInitialPacketSize + +// MinStatelessResetSize is the minimum size of a stateless reset packet that we send +const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */ + +// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet. +const MinConnectionIDLenInitial = 8 + +// DefaultAckDelayExponent is the default ack delay exponent +const DefaultAckDelayExponent = 3 + +// MaxAckDelayExponent is the maximum ack delay exponent +const MaxAckDelayExponent = 20 + +// DefaultMaxAckDelay is the default max_ack_delay +const DefaultMaxAckDelay = 25 * time.Millisecond + +// MaxMaxAckDelay is the maximum max_ack_delay +const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond + +// MaxConnIDLen is the maximum length of the connection ID +const MaxConnIDLen = 20 + +// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using +// AEAD_AES_128_GCM or AEAD_AES_265_GCM. +const InvalidPacketLimitAES = 1 << 52 + +// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305. +const InvalidPacketLimitChaCha = 1 << 36 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream.go new file mode 100644 index 00000000..ad7de864 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream.go @@ -0,0 +1,76 @@ +package protocol + +// StreamType encodes if this is a unidirectional or bidirectional stream +type StreamType uint8 + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni StreamType = iota + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi +) + +// InvalidPacketNumber is a stream ID that is invalid. +// The first valid stream ID in QUIC is 0. +const InvalidStreamID StreamID = -1 + +// StreamNum is the stream number +type StreamNum int64 + +const ( + // InvalidStreamNum is an invalid stream number. + InvalidStreamNum = -1 + // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames + // and as the stream count in the transport parameters + MaxStreamCount StreamNum = 1 << 60 +) + +// StreamID calculates the stream ID. +func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { + if s == 0 { + return InvalidStreamID + } + var first StreamID + switch stype { + case StreamTypeBidi: + switch pers { + case PerspectiveClient: + first = 0 + case PerspectiveServer: + first = 1 + } + case StreamTypeUni: + switch pers { + case PerspectiveClient: + first = 2 + case PerspectiveServer: + first = 3 + } + } + return first + 4*StreamID(s-1) +} + +// A StreamID in QUIC +type StreamID int64 + +// InitiatedBy says if the stream was initiated by the client or by the server +func (s StreamID) InitiatedBy() Perspective { + if s%2 == 0 { + return PerspectiveClient + } + return PerspectiveServer +} + +// Type says if this is a unidirectional or bidirectional stream +func (s StreamID) Type() StreamType { + if s%4 >= 2 { + return StreamTypeUni + } + return StreamTypeBidi +} + +// StreamNum returns how many streams in total are below this +// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) +func (s StreamID) StreamNum() StreamNum { + return StreamNum(s/4) + 1 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go new file mode 100644 index 00000000..dd54dbd3 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go @@ -0,0 +1,114 @@ +package protocol + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "math" +) + +// VersionNumber is a version number as int +type VersionNumber uint32 + +// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions +const ( + gquicVersion0 = 0x51303030 + maxGquicVersion = 0x51303439 +) + +// The version numbers, making grepping easier +const ( + VersionTLS VersionNumber = 0x1 + VersionWhatever VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter + VersionUnknown VersionNumber = math.MaxUint32 + VersionDraft29 VersionNumber = 0xff00001d + Version1 VersionNumber = 0x1 + Version2 VersionNumber = 0x709a50c4 +) + +// SupportedVersions lists the versions that the server supports +// must be in sorted descending order +var SupportedVersions = []VersionNumber{Version1, Version2, VersionDraft29} + +// IsValidVersion says if the version is known to quic-go +func IsValidVersion(v VersionNumber) bool { + return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) +} + +func (vn VersionNumber) String() string { + // For releases, VersionTLS will be set to a draft version. + // A switch statement can't contain duplicate cases. + if vn == VersionTLS && VersionTLS != VersionDraft29 && VersionTLS != Version1 { + return "TLS dev version (WIP)" + } + //nolint:exhaustive + switch vn { + case VersionWhatever: + return "whatever" + case VersionUnknown: + return "unknown" + case VersionDraft29: + return "draft-29" + case Version1: + return "v1" + case Version2: + return "v2" + default: + if vn.isGQUIC() { + return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) + } + return fmt.Sprintf("%#x", uint32(vn)) + } +} + +func (vn VersionNumber) isGQUIC() bool { + return vn > gquicVersion0 && vn <= maxGquicVersion +} + +func (vn VersionNumber) toGQUICVersion() int { + return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) +} + +// IsSupportedVersion returns true if the server supports this version +func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { + for _, t := range supported { + if t == v { + return true + } + } + return false +} + +// ChooseSupportedVersion finds the best version in the overlap of ours and theirs +// ours is a slice of versions that we support, sorted by our preference (descending) +// theirs is a slice of versions offered by the peer. The order does not matter. +// The bool returned indicates if a matching version was found. +func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { + for _, ourVer := range ours { + for _, theirVer := range theirs { + if ourVer == theirVer { + return ourVer, true + } + } + } + return 0, false +} + +// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) +func generateReservedVersion() VersionNumber { + b := make([]byte, 4) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) +} + +// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position +func GetGreasedVersions(supported []VersionNumber) []VersionNumber { + b := make([]byte, 1) + _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything + randPos := int(b[0]) % (len(supported) + 1) + greased := make([]VersionNumber, len(supported)+1) + copy(greased, supported[:randPos]) + greased[randPos] = generateReservedVersion() + copy(greased[randPos+1:], supported[randPos:]) + return greased +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qerr/error_codes.go b/vendor/github.com/lucas-clemente/quic-go/internal/qerr/error_codes.go new file mode 100644 index 00000000..bee42d51 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qerr/error_codes.go @@ -0,0 +1,88 @@ +package qerr + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/qtls" +) + +// TransportErrorCode is a QUIC transport error. +type TransportErrorCode uint64 + +// The error codes defined by QUIC +const ( + NoError TransportErrorCode = 0x0 + InternalError TransportErrorCode = 0x1 + ConnectionRefused TransportErrorCode = 0x2 + FlowControlError TransportErrorCode = 0x3 + StreamLimitError TransportErrorCode = 0x4 + StreamStateError TransportErrorCode = 0x5 + FinalSizeError TransportErrorCode = 0x6 + FrameEncodingError TransportErrorCode = 0x7 + TransportParameterError TransportErrorCode = 0x8 + ConnectionIDLimitError TransportErrorCode = 0x9 + ProtocolViolation TransportErrorCode = 0xa + InvalidToken TransportErrorCode = 0xb + ApplicationErrorErrorCode TransportErrorCode = 0xc + CryptoBufferExceeded TransportErrorCode = 0xd + KeyUpdateError TransportErrorCode = 0xe + AEADLimitReached TransportErrorCode = 0xf + NoViablePathError TransportErrorCode = 0x10 +) + +func (e TransportErrorCode) IsCryptoError() bool { + return e >= 0x100 && e < 0x200 +} + +// Message is a description of the error. +// It only returns a non-empty string for crypto errors. +func (e TransportErrorCode) Message() string { + if !e.IsCryptoError() { + return "" + } + return qtls.Alert(e - 0x100).Error() +} + +func (e TransportErrorCode) String() string { + switch e { + case NoError: + return "NO_ERROR" + case InternalError: + return "INTERNAL_ERROR" + case ConnectionRefused: + return "CONNECTION_REFUSED" + case FlowControlError: + return "FLOW_CONTROL_ERROR" + case StreamLimitError: + return "STREAM_LIMIT_ERROR" + case StreamStateError: + return "STREAM_STATE_ERROR" + case FinalSizeError: + return "FINAL_SIZE_ERROR" + case FrameEncodingError: + return "FRAME_ENCODING_ERROR" + case TransportParameterError: + return "TRANSPORT_PARAMETER_ERROR" + case ConnectionIDLimitError: + return "CONNECTION_ID_LIMIT_ERROR" + case ProtocolViolation: + return "PROTOCOL_VIOLATION" + case InvalidToken: + return "INVALID_TOKEN" + case ApplicationErrorErrorCode: + return "APPLICATION_ERROR" + case CryptoBufferExceeded: + return "CRYPTO_BUFFER_EXCEEDED" + case KeyUpdateError: + return "KEY_UPDATE_ERROR" + case AEADLimitReached: + return "AEAD_LIMIT_REACHED" + case NoViablePathError: + return "NO_VIABLE_PATH" + default: + if e.IsCryptoError() { + return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e)) + } + return fmt.Sprintf("unknown error code: %#x", uint16(e)) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qerr/errors.go b/vendor/github.com/lucas-clemente/quic-go/internal/qerr/errors.go new file mode 100644 index 00000000..8b1cff98 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qerr/errors.go @@ -0,0 +1,124 @@ +package qerr + +import ( + "fmt" + "net" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +var ( + ErrHandshakeTimeout = &HandshakeTimeoutError{} + ErrIdleTimeout = &IdleTimeoutError{} +) + +type TransportError struct { + Remote bool + FrameType uint64 + ErrorCode TransportErrorCode + ErrorMessage string +} + +var _ error = &TransportError{} + +// NewCryptoError create a new TransportError instance for a crypto error +func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { + return &TransportError{ + ErrorCode: 0x100 + TransportErrorCode(tlsAlert), + ErrorMessage: errorMessage, + } +} + +func (e *TransportError) Error() string { + str := e.ErrorCode.String() + if e.FrameType != 0 { + str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) + } + msg := e.ErrorMessage + if len(msg) == 0 { + msg = e.ErrorCode.Message() + } + if len(msg) == 0 { + return str + } + return str + ": " + msg +} + +func (e *TransportError) Is(target error) bool { + return target == net.ErrClosed +} + +// An ApplicationErrorCode is an application-defined error code. +type ApplicationErrorCode uint64 + +func (e *ApplicationError) Is(target error) bool { + return target == net.ErrClosed +} + +// A StreamErrorCode is an error code used to cancel streams. +type StreamErrorCode uint64 + +type ApplicationError struct { + Remote bool + ErrorCode ApplicationErrorCode + ErrorMessage string +} + +var _ error = &ApplicationError{} + +func (e *ApplicationError) Error() string { + if len(e.ErrorMessage) == 0 { + return fmt.Sprintf("Application error %#x", e.ErrorCode) + } + return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage) +} + +type IdleTimeoutError struct{} + +var _ error = &IdleTimeoutError{} + +func (e *IdleTimeoutError) Timeout() bool { return true } +func (e *IdleTimeoutError) Temporary() bool { return false } +func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } +func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } + +type HandshakeTimeoutError struct{} + +var _ error = &HandshakeTimeoutError{} + +func (e *HandshakeTimeoutError) Timeout() bool { return true } +func (e *HandshakeTimeoutError) Temporary() bool { return false } +func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } +func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } + +// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. +type VersionNegotiationError struct { + Ours []protocol.VersionNumber + Theirs []protocol.VersionNumber +} + +func (e *VersionNegotiationError) Error() string { + return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) +} + +func (e *VersionNegotiationError) Is(target error) bool { + return target == net.ErrClosed +} + +// A StatelessResetError occurs when we receive a stateless reset. +type StatelessResetError struct { + Token protocol.StatelessResetToken +} + +var _ net.Error = &StatelessResetError{} + +func (e *StatelessResetError) Error() string { + return fmt.Sprintf("received a stateless reset with token %x", e.Token) +} + +func (e *StatelessResetError) Is(target error) bool { + return target == net.ErrClosed +} + +func (e *StatelessResetError) Timeout() bool { return false } +func (e *StatelessResetError) Temporary() bool { return true } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go new file mode 100644 index 00000000..e3024624 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go @@ -0,0 +1,100 @@ +//go:build go1.16 && !go1.17 +// +build go1.16,!go1.17 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-16" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-16.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go117.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go117.go new file mode 100644 index 00000000..bc385f19 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go117.go @@ -0,0 +1,100 @@ +//go:build go1.17 && !go1.18 +// +build go1.17,!go1.18 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-17" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-17.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go new file mode 100644 index 00000000..5de030c7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go @@ -0,0 +1,100 @@ +//go:build go1.18 && !go1.19 +// +build go1.18,!go1.19 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-18" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains inforamtion about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-18.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go119.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go119.go new file mode 100644 index 00000000..86dcaea3 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go119.go @@ -0,0 +1,100 @@ +//go:build go1.19 +// +build go1.19 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "net" + "unsafe" + + "github.com/marten-seemann/qtls-go1-19" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains information about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-19.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go new file mode 100644 index 00000000..f8d59d8b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go @@ -0,0 +1,6 @@ +//go:build go1.20 +// +build go1.20 + +package qtls + +var _ int = "The version of quic-go you're using can't be built on Go 1.20 yet. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go new file mode 100644 index 00000000..384d719c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go @@ -0,0 +1,7 @@ +//go:build (go1.9 || go1.10 || go1.11 || go1.12 || go1.13 || go1.14 || go1.15) && !go1.16 +// +build go1.9 go1.10 go1.11 go1.12 go1.13 go1.14 go1.15 +// +build !go1.16 + +package qtls + +var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go new file mode 100644 index 00000000..cf464250 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go @@ -0,0 +1,22 @@ +package utils + +import "sync/atomic" + +// An AtomicBool is an atomic bool +type AtomicBool struct { + v int32 +} + +// Set sets the value +func (a *AtomicBool) Set(value bool) { + var n int32 + if value { + n = 1 + } + atomic.StoreInt32(&a.v, n) +} + +// Get gets the value +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.v) != 0 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/buffered_write_closer.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/buffered_write_closer.go new file mode 100644 index 00000000..b5b9d6fc --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/buffered_write_closer.go @@ -0,0 +1,26 @@ +package utils + +import ( + "bufio" + "io" +) + +type bufferedWriteCloser struct { + *bufio.Writer + io.Closer +} + +// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer +func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser { + return &bufferedWriteCloser{ + Writer: writer, + Closer: closer, + } +} + +func (h bufferedWriteCloser) Close() error { + if err := h.Writer.Flush(); err != nil { + return err + } + return h.Closer.Close() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go new file mode 100644 index 00000000..096023ef --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// ByteIntervalElement is an element of a linked list. +type ByteIntervalElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *ByteIntervalElement + + // The list to which this element belongs. + list *ByteIntervalList + + // The value stored with this element. + Value ByteInterval +} + +// Next returns the next list element or nil. +func (e *ByteIntervalElement) Next() *ByteIntervalElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *ByteIntervalElement) Prev() *ByteIntervalElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// ByteIntervalList is a linked list of ByteIntervals. +type ByteIntervalList struct { + root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *ByteIntervalList) Init() *ByteIntervalList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewByteIntervalList returns an initialized list. +func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *ByteIntervalList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *ByteIntervalList) Front() *ByteIntervalElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *ByteIntervalList) Back() *ByteIntervalElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *ByteIntervalList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { + return l.insert(&ByteIntervalElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go new file mode 100644 index 00000000..d1f52842 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go @@ -0,0 +1,17 @@ +package utils + +import ( + "bytes" + "io" +) + +// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. +type ByteOrder interface { + ReadUint32(io.ByteReader) (uint32, error) + ReadUint24(io.ByteReader) (uint32, error) + ReadUint16(io.ByteReader) (uint16, error) + + WriteUint32(*bytes.Buffer, uint32) + WriteUint24(*bytes.Buffer, uint32) + WriteUint16(*bytes.Buffer, uint16) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go new file mode 100644 index 00000000..d05542e1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go @@ -0,0 +1,89 @@ +package utils + +import ( + "bytes" + "io" +) + +// BigEndian is the big-endian implementation of ByteOrder. +var BigEndian ByteOrder = bigEndian{} + +type bigEndian struct{} + +var _ ByteOrder = &bigEndian{} + +// ReadUintN reads N bytes +func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) { + var res uint64 + for i := uint8(0); i < length; i++ { + bt, err := b.ReadByte() + if err != nil { + return 0, err + } + res ^= uint64(bt) << ((length - 1 - i) * 8) + } + return res, nil +} + +// ReadUint32 reads a uint32 +func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) { + var b1, b2, b3, b4 uint8 + var err error + if b4, err = b.ReadByte(); err != nil { + return 0, err + } + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil +} + +// ReadUint24 reads a uint24 +func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) { + var b1, b2, b3 uint8 + var err error + if b3, err = b.ReadByte(); err != nil { + return 0, err + } + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil +} + +// ReadUint16 reads a uint16 +func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { + var b1, b2 uint8 + var err error + if b2, err = b.ReadByte(); err != nil { + return 0, err + } + if b1, err = b.ReadByte(); err != nil { + return 0, err + } + return uint16(b1) + uint16(b2)<<8, nil +} + +// WriteUint32 writes a uint32 +func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint24 writes a uint24 +func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) { + b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)}) +} + +// WriteUint16 writes a uint16 +func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) { + b.Write([]byte{uint8(i >> 8), uint8(i)}) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go new file mode 100644 index 00000000..8a63e958 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go @@ -0,0 +1,5 @@ +package utils + +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out byteinterval_linkedlist.go gen Item=ByteInterval +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out packetinterval_linkedlist.go gen Item=PacketInterval +//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out newconnectionid_linkedlist.go gen Item=NewConnectionID diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/ip.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/ip.go new file mode 100644 index 00000000..7ac7ffec --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/ip.go @@ -0,0 +1,10 @@ +package utils + +import "net" + +func IsIPv4(ip net.IP) bool { + // If ip is not an IPv4 address, To4 returns nil. + // Note that there might be some corner cases, where this is not correct. + // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. + return ip.To4() != nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go new file mode 100644 index 00000000..e27f01b4 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go @@ -0,0 +1,131 @@ +package utils + +import ( + "fmt" + "log" + "os" + "strings" + "time" +) + +// LogLevel of quic-go +type LogLevel uint8 + +const ( + // LogLevelNothing disables + LogLevelNothing LogLevel = iota + // LogLevelError enables err logs + LogLevelError + // LogLevelInfo enables info logs (e.g. packets) + LogLevelInfo + // LogLevelDebug enables debug logs (e.g. packet contents) + LogLevelDebug +) + +const logEnv = "QUIC_GO_LOG_LEVEL" + +// A Logger logs. +type Logger interface { + SetLogLevel(LogLevel) + SetLogTimeFormat(format string) + WithPrefix(prefix string) Logger + Debug() bool + + Errorf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Debugf(format string, args ...interface{}) +} + +// DefaultLogger is used by quic-go for logging. +var DefaultLogger Logger + +type defaultLogger struct { + prefix string + + logLevel LogLevel + timeFormat string +} + +var _ Logger = &defaultLogger{} + +// SetLogLevel sets the log level +func (l *defaultLogger) SetLogLevel(level LogLevel) { + l.logLevel = level +} + +// SetLogTimeFormat sets the format of the timestamp +// an empty string disables the logging of timestamps +func (l *defaultLogger) SetLogTimeFormat(format string) { + log.SetFlags(0) // disable timestamp logging done by the log package + l.timeFormat = format +} + +// Debugf logs something +func (l *defaultLogger) Debugf(format string, args ...interface{}) { + if l.logLevel == LogLevelDebug { + l.logMessage(format, args...) + } +} + +// Infof logs something +func (l *defaultLogger) Infof(format string, args ...interface{}) { + if l.logLevel >= LogLevelInfo { + l.logMessage(format, args...) + } +} + +// Errorf logs something +func (l *defaultLogger) Errorf(format string, args ...interface{}) { + if l.logLevel >= LogLevelError { + l.logMessage(format, args...) + } +} + +func (l *defaultLogger) logMessage(format string, args ...interface{}) { + var pre string + + if len(l.timeFormat) > 0 { + pre = time.Now().Format(l.timeFormat) + " " + } + if len(l.prefix) > 0 { + pre += l.prefix + " " + } + log.Printf(pre+format, args...) +} + +func (l *defaultLogger) WithPrefix(prefix string) Logger { + if len(l.prefix) > 0 { + prefix = l.prefix + " " + prefix + } + return &defaultLogger{ + logLevel: l.logLevel, + timeFormat: l.timeFormat, + prefix: prefix, + } +} + +// Debug returns true if the log level is LogLevelDebug +func (l *defaultLogger) Debug() bool { + return l.logLevel == LogLevelDebug +} + +func init() { + DefaultLogger = &defaultLogger{} + DefaultLogger.SetLogLevel(readLoggingEnv()) +} + +func readLoggingEnv() LogLevel { + switch strings.ToLower(os.Getenv(logEnv)) { + case "": + return LogLevelNothing + case "debug": + return LogLevelDebug + case "info": + return LogLevelInfo + case "error": + return LogLevelError + default: + fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") + return LogLevelNothing + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go new file mode 100644 index 00000000..ee1f85f6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go @@ -0,0 +1,170 @@ +package utils + +import ( + "math" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// InfDuration is a duration of infinite length +const InfDuration = time.Duration(math.MaxInt64) + +// Max returns the maximum of two Ints +func Max(a, b int) int { + if a < b { + return b + } + return a +} + +// MaxUint32 returns the maximum of two uint32 +func MaxUint32(a, b uint32) uint32 { + if a < b { + return b + } + return a +} + +// MaxUint64 returns the maximum of two uint64 +func MaxUint64(a, b uint64) uint64 { + if a < b { + return b + } + return a +} + +// MinUint64 returns the maximum of two uint64 +func MinUint64(a, b uint64) uint64 { + if a < b { + return a + } + return b +} + +// Min returns the minimum of two Ints +func Min(a, b int) int { + if a < b { + return a + } + return b +} + +// MinUint32 returns the maximum of two uint32 +func MinUint32(a, b uint32) uint32 { + if a < b { + return a + } + return b +} + +// MinInt64 returns the minimum of two int64 +func MinInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +// MaxInt64 returns the minimum of two int64 +func MaxInt64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +// MinByteCount returns the minimum of two ByteCounts +func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return a + } + return b +} + +// MaxByteCount returns the maximum of two ByteCounts +func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { + if a < b { + return b + } + return a +} + +// MaxDuration returns the max duration +func MaxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +// MinDuration returns the minimum duration +func MinDuration(a, b time.Duration) time.Duration { + if a > b { + return b + } + return a +} + +// MinNonZeroDuration return the minimum duration that's not zero. +func MinNonZeroDuration(a, b time.Duration) time.Duration { + if a == 0 { + return b + } + if b == 0 { + return a + } + return MinDuration(a, b) +} + +// AbsDuration returns the absolute value of a time duration +func AbsDuration(d time.Duration) time.Duration { + if d >= 0 { + return d + } + return -d +} + +// MinTime returns the earlier time +func MinTime(a, b time.Time) time.Time { + if a.After(b) { + return b + } + return a +} + +// MinNonZeroTime returns the earlist time that is not time.Time{} +// If both a and b are time.Time{}, it returns time.Time{} +func MinNonZeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() { + return a + } + return MinTime(a, b) +} + +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} + +// MaxPacketNumber returns the max packet number +func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { + if a > b { + return a + } + return b +} + +// MinPacketNumber returns the min packet number +func MinPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { + if a < b { + return a + } + return b +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go new file mode 100644 index 00000000..694ee7aa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go @@ -0,0 +1,12 @@ +package utils + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// NewConnectionID is a new connection ID +type NewConnectionID struct { + SequenceNumber uint64 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go new file mode 100644 index 00000000..d59562e5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// NewConnectionIDElement is an element of a linked list. +type NewConnectionIDElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *NewConnectionIDElement + + // The list to which this element belongs. + list *NewConnectionIDList + + // The value stored with this element. + Value NewConnectionID +} + +// Next returns the next list element or nil. +func (e *NewConnectionIDElement) Next() *NewConnectionIDElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *NewConnectionIDElement) Prev() *NewConnectionIDElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// NewConnectionIDList is a linked list of NewConnectionIDs. +type NewConnectionIDList struct { + root NewConnectionIDElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *NewConnectionIDList) Init() *NewConnectionIDList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewNewConnectionIDList returns an initialized list. +func NewNewConnectionIDList() *NewConnectionIDList { return new(NewConnectionIDList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *NewConnectionIDList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *NewConnectionIDList) Front() *NewConnectionIDElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *NewConnectionIDList) Back() *NewConnectionIDElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *NewConnectionIDList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *NewConnectionIDList) insert(e, at *NewConnectionIDElement) *NewConnectionIDElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *NewConnectionIDList) insertValue(v NewConnectionID, at *NewConnectionIDElement) *NewConnectionIDElement { + return l.insert(&NewConnectionIDElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *NewConnectionIDList) remove(e *NewConnectionIDElement) *NewConnectionIDElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *NewConnectionIDList) Remove(e *NewConnectionIDElement) NewConnectionID { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *NewConnectionIDList) PushFront(v NewConnectionID) *NewConnectionIDElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *NewConnectionIDList) PushBack(v NewConnectionID) *NewConnectionIDElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *NewConnectionIDList) InsertBefore(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *NewConnectionIDList) InsertAfter(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *NewConnectionIDList) MoveToFront(e *NewConnectionIDElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *NewConnectionIDList) MoveToBack(e *NewConnectionIDElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *NewConnectionIDList) MoveBefore(e, mark *NewConnectionIDElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *NewConnectionIDList) MoveAfter(e, mark *NewConnectionIDElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *NewConnectionIDList) PushBackList(other *NewConnectionIDList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *NewConnectionIDList) PushFrontList(other *NewConnectionIDList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go new file mode 100644 index 00000000..62cc8b9c --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go @@ -0,0 +1,9 @@ +package utils + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// PacketInterval is an interval from one PacketNumber to the other +type PacketInterval struct { + Start protocol.PacketNumber + End protocol.PacketNumber +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go new file mode 100644 index 00000000..b461e85a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go @@ -0,0 +1,217 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package utils + +// Linked list implementation from the Go standard library. + +// PacketIntervalElement is an element of a linked list. +type PacketIntervalElement struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *PacketIntervalElement + + // The list to which this element belongs. + list *PacketIntervalList + + // The value stored with this element. + Value PacketInterval +} + +// Next returns the next list element or nil. +func (e *PacketIntervalElement) Next() *PacketIntervalElement { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *PacketIntervalElement) Prev() *PacketIntervalElement { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// PacketIntervalList is a linked list of PacketIntervals. +type PacketIntervalList struct { + root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *PacketIntervalList) Init() *PacketIntervalList { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// NewPacketIntervalList returns an initialized list. +func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *PacketIntervalList) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *PacketIntervalList) Front() *PacketIntervalElement { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *PacketIntervalList) Back() *PacketIntervalElement { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *PacketIntervalList) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketIntervalElement { + n := at.next + at.next = e + e.prev = at + e.next = n + n.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { + return l.insert(&PacketIntervalElement{Value: v}, at) +} + +// remove removes e from its list, decrements l.len, and returns e. +func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalElement { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + return e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *PacketIntervalList) PushFront(v PacketInterval) *PacketIntervalElement { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.insert(l.remove(e), l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { + if e.list != l || e == mark || mark.list != l { + return + } + l.insert(l.remove(e), mark) +} + +// PushBackList inserts a copy of an other list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of an other list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/rand.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/rand.go new file mode 100644 index 00000000..30069144 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/rand.go @@ -0,0 +1,29 @@ +package utils + +import ( + "crypto/rand" + "encoding/binary" +) + +// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand. +type Rand struct { + buf [4]byte +} + +func (r *Rand) Int31() int32 { + rand.Read(r.buf[:]) + return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31)) +} + +// copied from the standard library math/rand implementation of Int63n +func (r *Rand) Int31n(n int32) int32 { + if n&(n-1) == 0 { // n is power of two, can mask + return r.Int31() & (n - 1) + } + max := int32((1 << 31) - 1 - (1<<31)%uint32(n)) + v := r.Int31() + for v > max { + v = r.Int31() + } + return v % n +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/rtt_stats.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/rtt_stats.go new file mode 100644 index 00000000..66642ba8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/rtt_stats.go @@ -0,0 +1,127 @@ +package utils + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +const ( + rttAlpha = 0.125 + oneMinusAlpha = 1 - rttAlpha + rttBeta = 0.25 + oneMinusBeta = 1 - rttBeta + // The default RTT used before an RTT sample is taken. + defaultInitialRTT = 100 * time.Millisecond +) + +// RTTStats provides round-trip statistics +type RTTStats struct { + hasMeasurement bool + + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + meanDeviation time.Duration + + maxAckDelay time.Duration +} + +// NewRTTStats makes a properly initialized RTTStats object +func NewRTTStats() *RTTStats { + return &RTTStats{} +} + +// MinRTT Returns the minRTT for the entire connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } + +// LatestRTT returns the most recent rtt measurement. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT } + +// SmoothedRTT returns the smoothed RTT for the connection. +// May return Zero if no valid updates have occurred. +func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT } + +// MeanDeviation gets the mean deviation +func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation } + +// MaxAckDelay gets the max_ack_delay advertised by the peer +func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay } + +// PTO gets the probe timeout duration. +func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { + if r.SmoothedRTT() == 0 { + return 2 * defaultInitialRTT + } + pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + if includeMaxAckDelay { + pto += r.MaxAckDelay() + } + return pto +} + +// UpdateRTT updates the RTT based on a new sample. +func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { + if sendDelta == InfDuration || sendDelta <= 0 { + return + } + + // Update r.minRTT first. r.minRTT does not use an rttSample corrected for + // ackDelay but the raw observed sendDelta, since poor clock granularity at + // the client may cause a high ackDelay to result in underestimation of the + // r.minRTT. + if r.minRTT == 0 || r.minRTT > sendDelta { + r.minRTT = sendDelta + } + + // Correct for ackDelay if information received from the peer results in a + // an RTT sample at least as large as minRTT. Otherwise, only use the + // sendDelta. + sample := sendDelta + if sample-r.minRTT >= ackDelay { + sample -= ackDelay + } + r.latestRTT = sample + // First time call. + if !r.hasMeasurement { + r.hasMeasurement = true + r.smoothedRTT = sample + r.meanDeviation = sample / 2 + } else { + r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond + r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond + } +} + +// SetMaxAckDelay sets the max_ack_delay +func (r *RTTStats) SetMaxAckDelay(mad time.Duration) { + r.maxAckDelay = mad +} + +// SetInitialRTT sets the initial RTT. +// It is used during the 0-RTT handshake when restoring the RTT stats from the session state. +func (r *RTTStats) SetInitialRTT(t time.Duration) { + if r.hasMeasurement { + panic("initial RTT set after first measurement") + } + r.smoothedRTT = t + r.latestRTT = t +} + +// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. +func (r *RTTStats) OnConnectionMigration() { + r.latestRTT = 0 + r.minRTT = 0 + r.smoothedRTT = 0 + r.meanDeviation = 0 +} + +// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt +// is larger. The mean deviation is increased to the most recent deviation if +// it's larger. +func (r *RTTStats) ExpireSmoothedMetrics() { + r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) + r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go new file mode 100644 index 00000000..ec16d251 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go @@ -0,0 +1,9 @@ +package utils + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// ByteInterval is an interval from one ByteCount to the other +type ByteInterval struct { + Start protocol.ByteCount + End protocol.ByteCount +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go new file mode 100644 index 00000000..a4f5e67a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go @@ -0,0 +1,53 @@ +package utils + +import ( + "math" + "time" +) + +// A Timer wrapper that behaves correctly when resetting +type Timer struct { + t *time.Timer + read bool + deadline time.Time +} + +// NewTimer creates a new timer that is not set +func NewTimer() *Timer { + return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))} +} + +// Chan returns the channel of the wrapped timer +func (t *Timer) Chan() <-chan time.Time { + return t.t.C +} + +// Reset the timer, no matter whether the value was read or not +func (t *Timer) Reset(deadline time.Time) { + if deadline.Equal(t.deadline) && !t.read { + // No need to reset the timer + return + } + + // We need to drain the timer if the value from its channel was not read yet. + // See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU + if !t.t.Stop() && !t.read { + <-t.t.C + } + if !deadline.IsZero() { + t.t.Reset(time.Until(deadline)) + } + + t.read = false + t.deadline = deadline +} + +// SetRead should be called after the value from the chan was read +func (t *Timer) SetRead() { + t.read = true +} + +// Stop stops the timer +func (t *Timer) Stop() { + t.t.Stop() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go new file mode 100644 index 00000000..e5280a73 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go @@ -0,0 +1,251 @@ +package wire + +import ( + "bytes" + "errors" + "sort" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") + +// An AckFrame is an ACK frame +type AckFrame struct { + AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last + DelayTime time.Duration + + ECT0, ECT1, ECNCE uint64 +} + +// parseAckFrame reads an ACK frame +func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + ecn := typeByte&0x1 > 0 + + frame := &AckFrame{} + + la, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + largestAcked := protocol.PacketNumber(la) + delay, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + delayTime := time.Duration(delay*1< largestAcked { + return nil, errors.New("invalid first ACK range") + } + smallest := largestAcked - ackBlock + + // read all the other ACK ranges + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked}) + for i := uint64(0); i < numBlocks; i++ { + g, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + gap := protocol.PacketNumber(g) + if smallest < gap+2 { + return nil, errInvalidAckRanges + } + largest := smallest - gap - 2 + + ab, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + ackBlock := protocol.PacketNumber(ab) + + if ackBlock > largest { + return nil, errInvalidAckRanges + } + smallest = largest - ackBlock + frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) + } + + if !frame.validateAckRanges() { + return nil, errInvalidAckRanges + } + + // parse (and skip) the ECN section + if ecn { + for i := 0; i < 3; i++ { + if _, err := quicvarint.Read(r); err != nil { + return nil, err + } + } + } + + return frame, nil +} + +// Write writes an ACK frame. +func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 + if hasECN { + b.WriteByte(0x3) + } else { + b.WriteByte(0x2) + } + quicvarint.Write(b, uint64(f.LargestAcked())) + quicvarint.Write(b, encodeAckDelay(f.DelayTime)) + + numRanges := f.numEncodableAckRanges() + quicvarint.Write(b, uint64(numRanges-1)) + + // write the first range + _, firstRange := f.encodeAckRange(0) + quicvarint.Write(b, firstRange) + + // write all the other range + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + quicvarint.Write(b, gap) + quicvarint.Write(b, len) + } + + if hasECN { + quicvarint.Write(b, f.ECT0) + quicvarint.Write(b, f.ECT1) + quicvarint.Write(b, f.ECNCE) + } + return nil +} + +// Length of a written frame +func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + largestAcked := f.AckRanges[0].Largest + numRanges := f.numEncodableAckRanges() + + length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + + length += quicvarint.Len(uint64(numRanges - 1)) + lowestInFirstRange := f.AckRanges[0].Smallest + length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange)) + + for i := 1; i < numRanges; i++ { + gap, len := f.encodeAckRange(i) + length += quicvarint.Len(gap) + length += quicvarint.Len(len) + } + if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 { + length += quicvarint.Len(f.ECT0) + length += quicvarint.Len(f.ECT1) + length += quicvarint.Len(f.ECNCE) + } + return length +} + +// gets the number of ACK ranges that can be encoded +// such that the resulting frame is smaller than the maximum ACK frame size +func (f *AckFrame) numEncodableAckRanges() int { + length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime)) + length += 2 // assume that the number of ranges will consume 2 bytes + for i := 1; i < len(f.AckRanges); i++ { + gap, len := f.encodeAckRange(i) + rangeLen := quicvarint.Len(gap) + quicvarint.Len(len) + if length+rangeLen > protocol.MaxAckFrameSize { + // Writing range i would exceed the MaxAckFrameSize. + // So encode one range less than that. + return i - 1 + } + length += rangeLen + } + return len(f.AckRanges) +} + +func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) { + if i == 0 { + return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest) + } + return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2), + uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest) +} + +// HasMissingRanges returns if this frame reports any missing packets +func (f *AckFrame) HasMissingRanges() bool { + return len(f.AckRanges) > 1 +} + +func (f *AckFrame) validateAckRanges() bool { + if len(f.AckRanges) == 0 { + return false + } + + // check the validity of every single ACK range + for _, ackRange := range f.AckRanges { + if ackRange.Smallest > ackRange.Largest { + return false + } + } + + // check the consistency for ACK with multiple NACK ranges + for i, ackRange := range f.AckRanges { + if i == 0 { + continue + } + lastAckRange := f.AckRanges[i-1] + if lastAckRange.Smallest <= ackRange.Smallest { + return false + } + if lastAckRange.Smallest <= ackRange.Largest+1 { + return false + } + } + + return true +} + +// LargestAcked is the largest acked packet number +func (f *AckFrame) LargestAcked() protocol.PacketNumber { + return f.AckRanges[0].Largest +} + +// LowestAcked is the lowest acked packet number +func (f *AckFrame) LowestAcked() protocol.PacketNumber { + return f.AckRanges[len(f.AckRanges)-1].Smallest +} + +// AcksPacket determines if this ACK frame acks a certain packet number +func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { + if p < f.LowestAcked() || p > f.LargestAcked() { + return false + } + + i := sort.Search(len(f.AckRanges), func(i int) bool { + return p >= f.AckRanges[i].Smallest + }) + // i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked + return p <= f.AckRanges[i].Largest +} + +func encodeAckDelay(delay time.Duration) uint64 { + return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go new file mode 100644 index 00000000..0f418580 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go @@ -0,0 +1,14 @@ +package wire + +import "github.com/lucas-clemente/quic-go/internal/protocol" + +// AckRange is an ACK range +type AckRange struct { + Smallest protocol.PacketNumber + Largest protocol.PacketNumber +} + +// Len returns the number of packets contained in this ACK range +func (r AckRange) Len() protocol.PacketNumber { + return r.Largest - r.Smallest + 1 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go new file mode 100644 index 00000000..4ce49af6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go @@ -0,0 +1,83 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A ConnectionCloseFrame is a CONNECTION_CLOSE frame +type ConnectionCloseFrame struct { + IsApplicationError bool + ErrorCode uint64 + FrameType uint64 + ReasonPhrase string +} + +func parseConnectionCloseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d} + ec, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.ErrorCode = ec + // read the Frame Type, if this is not an application error + if !f.IsApplicationError { + ft, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.FrameType = ft + } + var reasonPhraseLen uint64 + reasonPhraseLen, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + // shortcut to prevent the unnecessary allocation of dataLen bytes + // if the dataLen is larger than the remaining length of the packet + // reading the whole reason phrase would result in EOF when attempting to READ + if int(reasonPhraseLen) > r.Len() { + return nil, io.EOF + } + + reasonPhrase := make([]byte, reasonPhraseLen) + if _, err := io.ReadFull(r, reasonPhrase); err != nil { + // this should never happen, since we already checked the reasonPhraseLen earlier + return nil, err + } + f.ReasonPhrase = string(reasonPhrase) + return f, nil +} + +// Length of a written frame +func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount { + length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) + if !f.IsApplicationError { + length += quicvarint.Len(f.FrameType) // for the frame type + } + return length +} + +func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if f.IsApplicationError { + b.WriteByte(0x1d) + } else { + b.WriteByte(0x1c) + } + + quicvarint.Write(b, f.ErrorCode) + if !f.IsApplicationError { + quicvarint.Write(b, f.FrameType) + } + quicvarint.Write(b, uint64(len(f.ReasonPhrase))) + b.WriteString(f.ReasonPhrase) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go new file mode 100644 index 00000000..6301c878 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go @@ -0,0 +1,102 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A CryptoFrame is a CRYPTO frame +type CryptoFrame struct { + Offset protocol.ByteCount + Data []byte +} + +func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &CryptoFrame{} + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + frame.Offset = protocol.ByteCount(offset) + dataLen, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if dataLen > uint64(r.Len()) { + return nil, io.EOF + } + if dataLen != 0 { + frame.Data = make([]byte, dataLen) + if _, err := io.ReadFull(r, frame.Data); err != nil { + // this should never happen, since we already checked the dataLen earlier + return nil, err + } + } + return frame, nil +} + +func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x6) + quicvarint.Write(b, uint64(f.Offset)) + quicvarint.Write(b, uint64(len(f.Data))) + b.Write(f.Data) + return nil +} + +// Length of a written frame +func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) +} + +// MaxDataLen returns the maximum data length +func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen := 1 + quicvarint.Len(uint64(f.Offset)) + 1 + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// It returns if the frame was actually split. +// The frame might not be split if: +// * the size is large enough to fit the whole frame +// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. +func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { + if f.Length(version) <= maxSize { + return nil, false + } + + n := f.MaxDataLen(maxSize) + if n == 0 { + return nil, true + } + + newLen := protocol.ByteCount(len(f.Data)) - n + + new := &CryptoFrame{} + new.Offset = f.Offset + new.Data = make([]byte, newLen) + + // swap the data slices + new.Data, f.Data = f.Data, new.Data + + copy(f.Data, new.Data[n:]) + new.Data = new.Data[:n] + f.Offset += n + + return new, true +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go new file mode 100644 index 00000000..459f04d1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A DataBlockedFrame is a DATA_BLOCKED frame +type DataBlockedFrame struct { + MaximumData protocol.ByteCount +} + +func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + return &DataBlockedFrame{ + MaximumData: protocol.ByteCount(offset), + }, nil +} + +func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + typeByte := uint8(0x14) + b.WriteByte(typeByte) + quicvarint.Write(b, uint64(f.MaximumData)) + return nil +} + +// Length of a written frame +func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.MaximumData)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/datagram_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/datagram_frame.go new file mode 100644 index 00000000..9d6e55cb --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/datagram_frame.go @@ -0,0 +1,85 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A DatagramFrame is a DATAGRAM frame +type DatagramFrame struct { + DataLenPresent bool + Data []byte +} + +func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &DatagramFrame{} + f.DataLenPresent = typeByte&0x1 > 0 + + var length uint64 + if f.DataLenPresent { + var err error + len, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if len > uint64(r.Len()) { + return nil, io.EOF + } + length = len + } else { + length = uint64(r.Len()) + } + f.Data = make([]byte, length) + if _, err := io.ReadFull(r, f.Data); err != nil { + return nil, err + } + return f, nil +} + +func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := uint8(0x30) + if f.DataLenPresent { + typeByte ^= 0x1 + } + b.WriteByte(typeByte) + if f.DataLenPresent { + quicvarint.Write(b, uint64(len(f.Data))) + } + b.Write(f.Data) + return nil +} + +// MaxDataLen returns the maximum data length +func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + headerLen := protocol.ByteCount(1) + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// Length of a written frame +func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + length := 1 + protocol.ByteCount(len(f.Data)) + if f.DataLenPresent { + length += quicvarint.Len(uint64(len(f.Data))) + } + return length +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go new file mode 100644 index 00000000..9d9edab2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go @@ -0,0 +1,249 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// ErrInvalidReservedBits is returned when the reserved bits are incorrect. +// When this error is returned, parsing continues, and an ExtendedHeader is returned. +// This is necessary because we need to decrypt the packet in that case, +// in order to avoid a timing side-channel. +var ErrInvalidReservedBits = errors.New("invalid reserved bits") + +// ExtendedHeader is the header of a QUIC packet. +type ExtendedHeader struct { + Header + + typeByte byte + + KeyPhase protocol.KeyPhaseBit + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + + parsedLen protocol.ByteCount +} + +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { + startLen := b.Len() + // read the (now unencrypted) first byte + var err error + h.typeByte, err = b.ReadByte() + if err != nil { + return false, err + } + if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { + return false, err + } + var reservedBitsValid bool + if h.IsLongHeader { + reservedBitsValid, err = h.parseLongHeader(b, v) + } else { + reservedBitsValid, err = h.parseShortHeader(b, v) + } + if err != nil { + return false, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return reservedBitsValid, err +} + +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { + if err := h.readPacketNumber(b); err != nil { + return false, err + } + if h.typeByte&0xc != 0 { + return false, nil + } + return true, nil +} + +func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { + h.KeyPhase = protocol.KeyPhaseZero + if h.typeByte&0x4 > 0 { + h.KeyPhase = protocol.KeyPhaseOne + } + + if err := h.readPacketNumber(b); err != nil { + return false, err + } + if h.typeByte&0x18 != 0 { + return false, nil + } + return true, nil +} + +func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { + h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + n, err := b.ReadByte() + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen2: + n, err := utils.BigEndian.ReadUint16(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen3: + n, err := utils.BigEndian.ReadUint24(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen4: + n, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + return nil +} + +// Write writes the Header. +func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { + if h.DestConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) + } + if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) + } + if h.IsLongHeader { + return h.writeLongHeader(b, ver) + } + return h.writeShortHeader(b, ver) +} + +func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error { + var packetType uint8 + if version == protocol.Version2 { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b01 + case protocol.PacketType0RTT: + packetType = 0b10 + case protocol.PacketTypeHandshake: + packetType = 0b11 + case protocol.PacketTypeRetry: + packetType = 0b00 + } + } else { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b00 + case protocol.PacketType0RTT: + packetType = 0b01 + case protocol.PacketTypeHandshake: + packetType = 0b10 + case protocol.PacketTypeRetry: + packetType = 0b11 + } + } + firstByte := 0xc0 | packetType<<4 + if h.Type != protocol.PacketTypeRetry { + // Retry packets don't have a packet number + firstByte |= uint8(h.PacketNumberLen - 1) + } + + b.WriteByte(firstByte) + utils.BigEndian.WriteUint32(b, uint32(h.Version)) + b.WriteByte(uint8(h.DestConnectionID.Len())) + b.Write(h.DestConnectionID.Bytes()) + b.WriteByte(uint8(h.SrcConnectionID.Len())) + b.Write(h.SrcConnectionID.Bytes()) + + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeRetry: + b.Write(h.Token) + return nil + case protocol.PacketTypeInitial: + quicvarint.Write(b, uint64(len(h.Token))) + b.Write(h.Token) + } + quicvarint.WriteWithLen(b, uint64(h.Length), 2) + return h.writePacketNumber(b) +} + +func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { + typeByte := 0x40 | uint8(h.PacketNumberLen-1) + if h.KeyPhase == protocol.KeyPhaseOne { + typeByte |= byte(1 << 2) + } + + b.WriteByte(typeByte) + b.Write(h.DestConnectionID.Bytes()) + return h.writePacketNumber(b) +} + +func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(h.PacketNumber)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) + case protocol.PacketNumberLen3: + utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + return nil +} + +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + +// GetLength determines the length of the Header. +func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { + if h.IsLongHeader { + length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ + if h.Type == protocol.PacketTypeInitial { + length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + } + return length + } + + length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) + length += protocol.ByteCount(h.PacketNumberLen) + return length +} + +// Log logs the Header +func (h *ExtendedHeader) Log(logger utils.Logger) { + if h.IsLongHeader { + var token string + if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { + if len(h.Token) == 0 { + token = "Token: (empty), " + } else { + token = fmt.Sprintf("Token: %#x, ", h.Token) + } + if h.Type == protocol.PacketTypeRetry { + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) + return + } + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) + } else { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go new file mode 100644 index 00000000..f3a51ecb --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go @@ -0,0 +1,143 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" +) + +type frameParser struct { + ackDelayExponent uint8 + + supportsDatagrams bool + + version protocol.VersionNumber +} + +// NewFrameParser creates a new frame parser. +func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser { + return &frameParser{ + supportsDatagrams: supportsDatagrams, + version: v, + } +} + +// ParseNext parses the next frame. +// It skips PADDING frames. +func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { + for r.Len() != 0 { + typeByte, _ := r.ReadByte() + if typeByte == 0x0 { // PADDING frame + continue + } + r.UnreadByte() + + f, err := p.parseFrame(r, typeByte, encLevel) + if err != nil { + return nil, &qerr.TransportError{ + FrameType: uint64(typeByte), + ErrorCode: qerr.FrameEncodingError, + ErrorMessage: err.Error(), + } + } + return f, nil + } + return nil, nil +} + +func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { + var frame Frame + var err error + if typeByte&0xf8 == 0x8 { + frame, err = parseStreamFrame(r, p.version) + } else { + switch typeByte { + case 0x1: + frame, err = parsePingFrame(r, p.version) + case 0x2, 0x3: + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent + } + frame, err = parseAckFrame(r, ackDelayExponent, p.version) + case 0x4: + frame, err = parseResetStreamFrame(r, p.version) + case 0x5: + frame, err = parseStopSendingFrame(r, p.version) + case 0x6: + frame, err = parseCryptoFrame(r, p.version) + case 0x7: + frame, err = parseNewTokenFrame(r, p.version) + case 0x10: + frame, err = parseMaxDataFrame(r, p.version) + case 0x11: + frame, err = parseMaxStreamDataFrame(r, p.version) + case 0x12, 0x13: + frame, err = parseMaxStreamsFrame(r, p.version) + case 0x14: + frame, err = parseDataBlockedFrame(r, p.version) + case 0x15: + frame, err = parseStreamDataBlockedFrame(r, p.version) + case 0x16, 0x17: + frame, err = parseStreamsBlockedFrame(r, p.version) + case 0x18: + frame, err = parseNewConnectionIDFrame(r, p.version) + case 0x19: + frame, err = parseRetireConnectionIDFrame(r, p.version) + case 0x1a: + frame, err = parsePathChallengeFrame(r, p.version) + case 0x1b: + frame, err = parsePathResponseFrame(r, p.version) + case 0x1c, 0x1d: + frame, err = parseConnectionCloseFrame(r, p.version) + case 0x1e: + frame, err = parseHandshakeDoneFrame(r, p.version) + case 0x30, 0x31: + if p.supportsDatagrams { + frame, err = parseDatagramFrame(r, p.version) + break + } + fallthrough + default: + err = errors.New("unknown frame type") + } + } + if err != nil { + return nil, err + } + if !p.isAllowedAtEncLevel(frame, encLevel) { + return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) + } + return frame, nil +} + +func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { + switch encLevel { + case protocol.EncryptionInitial, protocol.EncryptionHandshake: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: + return true + default: + return false + } + case protocol.Encryption0RTT: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: + return false + default: + return true + } + case protocol.Encryption1RTT: + return true + default: + panic("unknown encryption level") + } +} + +func (p *frameParser) SetAckDelayExponent(exp uint8) { + p.ackDelayExponent = exp +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go new file mode 100644 index 00000000..158d659f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go @@ -0,0 +1,28 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A HandshakeDoneFrame is a HANDSHAKE_DONE frame +type HandshakeDoneFrame struct{} + +// ParseHandshakeDoneFrame parses a HandshakeDone frame +func parseHandshakeDoneFrame(r *bytes.Reader, _ protocol.VersionNumber) (*HandshakeDoneFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &HandshakeDoneFrame{}, nil +} + +func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x1e) + return nil +} + +// Length of a written frame +func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go new file mode 100644 index 00000000..f6a31ee0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go @@ -0,0 +1,274 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// ParseConnectionID parses the destination connection ID of a packet. +// It uses the data slice for the connection ID. +// That means that the connection ID must not be used after the packet buffer is released. +func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { + if len(data) == 0 { + return nil, io.EOF + } + isLongHeader := data[0]&0x80 > 0 + if !isLongHeader { + if len(data) < shortHeaderConnIDLen+1 { + return nil, io.EOF + } + return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + } + if len(data) < 6 { + return nil, io.EOF + } + destConnIDLen := int(data[5]) + if len(data) < 6+destConnIDLen { + return nil, io.EOF + } + return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil +} + +// IsVersionNegotiationPacket says if this is a version negotiation packet +func IsVersionNegotiationPacket(b []byte) bool { + if len(b) < 5 { + return false + } + return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 +} + +// Is0RTTPacket says if this is a 0-RTT packet. +// A packet sent with a version we don't understand can never be a 0-RTT packet. +func Is0RTTPacket(b []byte) bool { + if len(b) < 5 { + return false + } + if b[0]&0x80 == 0 { + return false + } + version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) + if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) { + return false + } + if version == protocol.Version2 { + return b[0]>>4&0b11 == 0b10 + } + return b[0]>>4&0b11 == 0b01 +} + +var ErrUnsupportedVersion = errors.New("unsupported version") + +// The Header is the version independent part of the header +type Header struct { + IsLongHeader bool + typeByte byte + Type protocol.PacketType + + Version protocol.VersionNumber + SrcConnectionID protocol.ConnectionID + DestConnectionID protocol.ConnectionID + + Length protocol.ByteCount + + Token []byte + + parsedLen protocol.ByteCount // how many bytes were read while parsing this header +} + +// ParsePacket parses a packet. +// If the packet has a long header, the packet is cut according to the length field. +// If we understand the version, the packet is header up unto the packet number. +// Otherwise, only the invariant part of the header is parsed. +func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) { + hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen) + if err != nil { + if err == ErrUnsupportedVersion { + return hdr, nil, nil, ErrUnsupportedVersion + } + return nil, nil, nil, err + } + var rest []byte + if hdr.IsLongHeader { + if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { + return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + packetLen := int(hdr.ParsedLen() + hdr.Length) + rest = data[packetLen:] + data = data[:packetLen] + } + return hdr, data, rest, nil +} + +// ParseHeader parses the header. +// For short header packets: up to the packet number. +// For long header packets: +// * if we understand the version: up to the packet number +// * if not, only the invariant part of the header +func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + startLen := b.Len() + h, err := parseHeaderImpl(b, shortHeaderConnIDLen) + if err != nil { + return h, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return h, err +} + +func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { + typeByte, err := b.ReadByte() + if err != nil { + return nil, err + } + + h := &Header{ + typeByte: typeByte, + IsLongHeader: typeByte&0x80 > 0, + } + + if !h.IsLongHeader { + if h.typeByte&0x40 == 0 { + return nil, errors.New("not a QUIC packet") + } + if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil { + return nil, err + } + return h, nil + } + return h, h.parseLongHeader(b) +} + +func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { + var err error + h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen) + return err +} + +func (h *Header) parseLongHeader(b *bytes.Reader) error { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.Version = protocol.VersionNumber(v) + if h.Version != 0 && h.typeByte&0x40 == 0 { + return errors.New("not a QUIC packet") + } + destConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen)) + if err != nil { + return err + } + srcConnIDLen, err := b.ReadByte() + if err != nil { + return err + } + h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen)) + if err != nil { + return err + } + if h.Version == 0 { // version negotiation packet + return nil + } + // If we don't understand the version, we have no idea how to interpret the rest of the bytes + if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) { + return ErrUnsupportedVersion + } + + if h.Version == protocol.Version2 { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeRetry + case 0b01: + h.Type = protocol.PacketTypeInitial + case 0b10: + h.Type = protocol.PacketType0RTT + case 0b11: + h.Type = protocol.PacketTypeHandshake + } + } else { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeInitial + case 0b01: + h.Type = protocol.PacketType0RTT + case 0b10: + h.Type = protocol.PacketTypeHandshake + case 0b11: + h.Type = protocol.PacketTypeRetry + } + } + + if h.Type == protocol.PacketTypeRetry { + tokenLen := b.Len() - 16 + if tokenLen <= 0 { + return io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err + } + _, err := b.Seek(16, io.SeekCurrent) + return err + } + + if h.Type == protocol.PacketTypeInitial { + tokenLen, err := quicvarint.Read(b) + if err != nil { + return err + } + if tokenLen > uint64(b.Len()) { + return io.EOF + } + h.Token = make([]byte, tokenLen) + if _, err := io.ReadFull(b, h.Token); err != nil { + return err + } + } + + pl, err := quicvarint.Read(b) + if err != nil { + return err + } + h.Length = protocol.ByteCount(pl) + return nil +} + +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *Header) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + +// ParseExtended parses the version dependent part of the header. +// The Reader has to be set such that it points to the first byte of the header. +func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { + extHdr := h.toExtendedHeader() + reservedBitsValid, err := extHdr.parse(b, ver) + if err != nil { + return nil, err + } + if !reservedBitsValid { + return extHdr, ErrInvalidReservedBits + } + return extHdr, nil +} + +func (h *Header) toExtendedHeader() *ExtendedHeader { + return &ExtendedHeader{Header: *h} +} + +// PacketType is the type of the packet, for logging purposes +func (h *Header) PacketType() string { + if h.IsLongHeader { + return h.Type.String() + } + return "1-RTT" +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/interface.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/interface.go new file mode 100644 index 00000000..99fdc80f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/interface.go @@ -0,0 +1,19 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A Frame in QUIC +type Frame interface { + Write(b *bytes.Buffer, version protocol.VersionNumber) error + Length(version protocol.VersionNumber) protocol.ByteCount +} + +// A FrameParser parses QUIC frames, one by one. +type FrameParser interface { + ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) + SetAckDelayExponent(uint8) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go new file mode 100644 index 00000000..30cf9424 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go @@ -0,0 +1,72 @@ +package wire + +import ( + "fmt" + "strings" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// LogFrame logs a frame, either sent or received +func LogFrame(logger utils.Logger, frame Frame, sent bool) { + if !logger.Debug() { + return + } + dir := "<-" + if sent { + dir = "->" + } + switch f := frame.(type) { + case *CryptoFrame: + dataLen := protocol.ByteCount(len(f.Data)) + logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen) + case *StreamFrame: + logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen()) + case *ResetStreamFrame: + logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize) + case *AckFrame: + hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 + var ecn string + if hasECN { + ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE) + } + if len(f.AckRanges) > 1 { + ackRanges := make([]string, len(f.AckRanges)) + for i, r := range f.AckRanges { + ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest) + } + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn) + } else { + logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn) + } + case *MaxDataFrame: + logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData) + case *MaxStreamDataFrame: + logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) + case *DataBlockedFrame: + logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData) + case *StreamDataBlockedFrame: + logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData) + case *MaxStreamsFrame: + switch f.Type { + case protocol.StreamTypeUni: + logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum) + case protocol.StreamTypeBidi: + logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum) + } + case *StreamsBlockedFrame: + switch f.Type { + case protocol.StreamTypeUni: + logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit) + case protocol.StreamTypeBidi: + logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) + } + case *NewConnectionIDFrame: + logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) + case *NewTokenFrame: + logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) + default: + logger.Debugf("\t%s %#v", dir, frame) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go new file mode 100644 index 00000000..a9a09248 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go @@ -0,0 +1,40 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A MaxDataFrame carries flow control information for the connection +type MaxDataFrame struct { + MaximumData protocol.ByteCount +} + +// parseMaxDataFrame parses a MAX_DATA frame +func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + frame := &MaxDataFrame{} + byteOffset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + frame.MaximumData = protocol.ByteCount(byteOffset) + return frame, nil +} + +// Write writes a MAX_STREAM_DATA frame +func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x10) + quicvarint.Write(b, uint64(f.MaximumData)) + return nil +} + +// Length of a written frame +func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.MaximumData)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go new file mode 100644 index 00000000..728ecbe8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go @@ -0,0 +1,46 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A MaxStreamDataFrame is a MAX_STREAM_DATA frame +type MaxStreamDataFrame struct { + StreamID protocol.StreamID + MaximumStreamData protocol.ByteCount +} + +func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &MaxStreamDataFrame{ + StreamID: protocol.StreamID(sid), + MaximumStreamData: protocol.ByteCount(offset), + }, nil +} + +func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x11) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.MaximumStreamData)) + return nil +} + +// Length of a written frame +func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_streams_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_streams_frame.go new file mode 100644 index 00000000..73d7e13e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_streams_frame.go @@ -0,0 +1,55 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A MaxStreamsFrame is a MAX_STREAMS frame +type MaxStreamsFrame struct { + Type protocol.StreamType + MaxStreamNum protocol.StreamNum +} + +func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &MaxStreamsFrame{} + switch typeByte { + case 0x12: + f.Type = protocol.StreamTypeBidi + case 0x13: + f.Type = protocol.StreamTypeUni + } + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.MaxStreamNum = protocol.StreamNum(streamID) + if f.MaxStreamNum > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) + } + return f, nil +} + +func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + switch f.Type { + case protocol.StreamTypeBidi: + b.WriteByte(0x12) + case protocol.StreamTypeUni: + b.WriteByte(0x13) + } + quicvarint.Write(b, uint64(f.MaxStreamNum)) + return nil +} + +// Length of a written frame +func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_connection_id_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_connection_id_frame.go new file mode 100644 index 00000000..1a017ba9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_connection_id_frame.go @@ -0,0 +1,80 @@ +package wire + +import ( + "bytes" + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A NewConnectionIDFrame is a NEW_CONNECTION_ID frame +type NewConnectionIDFrame struct { + SequenceNumber uint64 + RetirePriorTo uint64 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + +func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + seq, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + ret, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if ret > seq { + //nolint:stylecheck + return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq) + } + connIDLen, err := r.ReadByte() + if err != nil { + return nil, err + } + if connIDLen > protocol.MaxConnIDLen { + return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + connID, err := protocol.ReadConnectionID(r, int(connIDLen)) + if err != nil { + return nil, err + } + frame := &NewConnectionIDFrame{ + SequenceNumber: seq, + RetirePriorTo: ret, + ConnectionID: connID, + } + if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + + return frame, nil +} + +func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x18) + quicvarint.Write(b, f.SequenceNumber) + quicvarint.Write(b, f.RetirePriorTo) + connIDLen := f.ConnectionID.Len() + if connIDLen > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + b.WriteByte(uint8(connIDLen)) + b.Write(f.ConnectionID.Bytes()) + b.Write(f.StatelessResetToken[:]) + return nil +} + +// Length of a written frame +func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_token_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_token_frame.go new file mode 100644 index 00000000..3d5d5c3a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_token_frame.go @@ -0,0 +1,48 @@ +package wire + +import ( + "bytes" + "errors" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A NewTokenFrame is a NEW_TOKEN frame +type NewTokenFrame struct { + Token []byte +} + +func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + tokenLen, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + if uint64(r.Len()) < tokenLen { + return nil, io.EOF + } + if tokenLen == 0 { + return nil, errors.New("token must not be empty") + } + token := make([]byte, int(tokenLen)) + if _, err := io.ReadFull(r, token); err != nil { + return nil, err + } + return &NewTokenFrame{Token: token}, nil +} + +func (f *NewTokenFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x7) + quicvarint.Write(b, uint64(len(f.Token))) + b.Write(f.Token) + return nil +} + +// Length of a written frame +func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go new file mode 100644 index 00000000..5ec82177 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A PathChallengeFrame is a PATH_CHALLENGE frame +type PathChallengeFrame struct { + Data [8]byte +} + +func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathChallengeFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x1a) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go new file mode 100644 index 00000000..262819f8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go @@ -0,0 +1,38 @@ +package wire + +import ( + "bytes" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A PathResponseFrame is a PATH_RESPONSE frame +type PathResponseFrame struct { + Data [8]byte +} + +func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + frame := &PathResponseFrame{} + if _, err := io.ReadFull(r, frame.Data[:]); err != nil { + if err == io.ErrUnexpectedEOF { + return nil, io.EOF + } + return nil, err + } + return frame, nil +} + +func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x1b) + b.Write(f.Data[:]) + return nil +} + +// Length of a written frame +func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + 8 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go new file mode 100644 index 00000000..dc029e45 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go @@ -0,0 +1,27 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// A PingFrame is a PING frame +type PingFrame struct{} + +func parsePingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PingFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + return &PingFrame{}, nil +} + +func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x1) + return nil +} + +// Length of a written frame +func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/pool.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/pool.go new file mode 100644 index 00000000..c057395e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/pool.go @@ -0,0 +1,33 @@ +package wire + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +var pool sync.Pool + +func init() { + pool.New = func() interface{} { + return &StreamFrame{ + Data: make([]byte, 0, protocol.MaxPacketBufferSize), + fromPool: true, + } + } +} + +func GetStreamFrame() *StreamFrame { + f := pool.Get().(*StreamFrame) + return f +} + +func putStreamFrame(f *StreamFrame) { + if !f.fromPool { + return + } + if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize { + panic("wire.PutStreamFrame called with packet of wrong size!") + } + pool.Put(f) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/reset_stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/reset_stream_frame.go new file mode 100644 index 00000000..69bbc2b9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/reset_stream_frame.go @@ -0,0 +1,58 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A ResetStreamFrame is a RESET_STREAM frame in QUIC +type ResetStreamFrame struct { + StreamID protocol.StreamID + ErrorCode qerr.StreamErrorCode + FinalSize protocol.ByteCount +} + +func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) { + if _, err := r.ReadByte(); err != nil { // read the TypeByte + return nil, err + } + + var streamID protocol.StreamID + var byteOffset protocol.ByteCount + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + streamID = protocol.StreamID(sid) + errorCode, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + bo, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + byteOffset = protocol.ByteCount(bo) + + return &ResetStreamFrame{ + StreamID: streamID, + ErrorCode: qerr.StreamErrorCode(errorCode), + FinalSize: byteOffset, + }, nil +} + +func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x4) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.ErrorCode)) + quicvarint.Write(b, uint64(f.FinalSize)) + return nil +} + +// Length of a written frame +func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/retire_connection_id_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/retire_connection_id_frame.go new file mode 100644 index 00000000..0f7e58c8 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/retire_connection_id_frame.go @@ -0,0 +1,36 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame +type RetireConnectionIDFrame struct { + SequenceNumber uint64 +} + +func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + seq, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + return &RetireConnectionIDFrame{SequenceNumber: seq}, nil +} + +func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x19) + quicvarint.Write(b, f.SequenceNumber) + return nil +} + +// Length of a written frame +func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(f.SequenceNumber) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go new file mode 100644 index 00000000..fb1160c1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go @@ -0,0 +1,48 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A StopSendingFrame is a STOP_SENDING frame +type StopSendingFrame struct { + StreamID protocol.StreamID + ErrorCode qerr.StreamErrorCode +} + +// parseStopSendingFrame parses a STOP_SENDING frame +func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + errorCode, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &StopSendingFrame{ + StreamID: protocol.StreamID(streamID), + ErrorCode: qerr.StreamErrorCode(errorCode), + }, nil +} + +// Length of a written frame +func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) +} + +func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + b.WriteByte(0x5) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.ErrorCode)) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_data_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_data_blocked_frame.go new file mode 100644 index 00000000..dc6d631a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_data_blocked_frame.go @@ -0,0 +1,46 @@ +package wire + +import ( + "bytes" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame +type StreamDataBlockedFrame struct { + StreamID protocol.StreamID + MaximumStreamData protocol.ByteCount +} + +func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) { + if _, err := r.ReadByte(); err != nil { + return nil, err + } + + sid, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + offset, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + + return &StreamDataBlockedFrame{ + StreamID: protocol.StreamID(sid), + MaximumStreamData: protocol.ByteCount(offset), + }, nil +} + +func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + b.WriteByte(0x15) + quicvarint.Write(b, uint64(f.StreamID)) + quicvarint.Write(b, uint64(f.MaximumStreamData)) + return nil +} + +// Length of a written frame +func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go new file mode 100644 index 00000000..66340d16 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go @@ -0,0 +1,189 @@ +package wire + +import ( + "bytes" + "errors" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A StreamFrame of QUIC +type StreamFrame struct { + StreamID protocol.StreamID + Offset protocol.ByteCount + Data []byte + Fin bool + DataLenPresent bool + + fromPool bool +} + +func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + hasOffset := typeByte&0x4 > 0 + fin := typeByte&0x1 > 0 + hasDataLen := typeByte&0x2 > 0 + + streamID, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + var offset uint64 + if hasOffset { + offset, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + } + + var dataLen uint64 + if hasDataLen { + var err error + dataLen, err = quicvarint.Read(r) + if err != nil { + return nil, err + } + } else { + // The rest of the packet is data + dataLen = uint64(r.Len()) + } + + var frame *StreamFrame + if dataLen < protocol.MinStreamFrameBufferSize { + frame = &StreamFrame{Data: make([]byte, dataLen)} + } else { + frame = GetStreamFrame() + // The STREAM frame can't be larger than the StreamFrame we obtained from the buffer, + // since those StreamFrames have a buffer length of the maximum packet size. + if dataLen > uint64(cap(frame.Data)) { + return nil, io.EOF + } + frame.Data = frame.Data[:dataLen] + } + + frame.StreamID = protocol.StreamID(streamID) + frame.Offset = protocol.ByteCount(offset) + frame.Fin = fin + frame.DataLenPresent = hasDataLen + + if dataLen != 0 { + if _, err := io.ReadFull(r, frame.Data); err != nil { + return nil, err + } + } + if frame.Offset+frame.DataLen() > protocol.MaxByteCount { + return nil, errors.New("stream data overflows maximum offset") + } + return frame, nil +} + +// Write writes a STREAM frame +func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { + if len(f.Data) == 0 && !f.Fin { + return errors.New("StreamFrame: attempting to write empty frame without FIN") + } + + typeByte := byte(0x8) + if f.Fin { + typeByte ^= 0x1 + } + hasOffset := f.Offset != 0 + if f.DataLenPresent { + typeByte ^= 0x2 + } + if hasOffset { + typeByte ^= 0x4 + } + b.WriteByte(typeByte) + quicvarint.Write(b, uint64(f.StreamID)) + if hasOffset { + quicvarint.Write(b, uint64(f.Offset)) + } + if f.DataLenPresent { + quicvarint.Write(b, uint64(f.DataLen())) + } + b.Write(f.Data) + return nil +} + +// Length returns the total length of the STREAM frame +func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { + length := 1 + quicvarint.Len(uint64(f.StreamID)) + if f.Offset != 0 { + length += quicvarint.Len(uint64(f.Offset)) + } + if f.DataLenPresent { + length += quicvarint.Len(uint64(f.DataLen())) + } + return length + f.DataLen() +} + +// DataLen gives the length of data in bytes +func (f *StreamFrame) DataLen() protocol.ByteCount { + return protocol.ByteCount(len(f.Data)) +} + +// MaxDataLen returns the maximum data length +// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). +func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { + headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) + if f.Offset != 0 { + headerLen += quicvarint.Len(uint64(f.Offset)) + } + if f.DataLenPresent { + // pretend that the data size will be 1 bytes + // if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards + headerLen++ + } + if headerLen > maxSize { + return 0 + } + maxDataLen := maxSize - headerLen + if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 { + maxDataLen-- + } + return maxDataLen +} + +// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes. +// It returns if the frame was actually split. +// The frame might not be split if: +// * the size is large enough to fit the whole frame +// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. +func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) { + if maxSize >= f.Length(version) { + return nil, false + } + + n := f.MaxDataLen(maxSize, version) + if n == 0 { + return nil, true + } + + new := GetStreamFrame() + new.StreamID = f.StreamID + new.Offset = f.Offset + new.Fin = false + new.DataLenPresent = f.DataLenPresent + + // swap the data slices + new.Data, f.Data = f.Data, new.Data + new.fromPool, f.fromPool = f.fromPool, new.fromPool + + f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n] + copy(f.Data, new.Data[n:]) + new.Data = new.Data[:n] + f.Offset += n + + return new, true +} + +func (f *StreamFrame) PutBack() { + putStreamFrame(f) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/streams_blocked_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/streams_blocked_frame.go new file mode 100644 index 00000000..f4066071 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/streams_blocked_frame.go @@ -0,0 +1,55 @@ +package wire + +import ( + "bytes" + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +// A StreamsBlockedFrame is a STREAMS_BLOCKED frame +type StreamsBlockedFrame struct { + Type protocol.StreamType + StreamLimit protocol.StreamNum +} + +func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { + typeByte, err := r.ReadByte() + if err != nil { + return nil, err + } + + f := &StreamsBlockedFrame{} + switch typeByte { + case 0x16: + f.Type = protocol.StreamTypeBidi + case 0x17: + f.Type = protocol.StreamTypeUni + } + streamLimit, err := quicvarint.Read(r) + if err != nil { + return nil, err + } + f.StreamLimit = protocol.StreamNum(streamLimit) + if f.StreamLimit > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) + } + return f, nil +} + +func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { + switch f.Type { + case protocol.StreamTypeBidi: + b.WriteByte(0x16) + case protocol.StreamTypeUni: + b.WriteByte(0x17) + } + quicvarint.Write(b, uint64(f.StreamLimit)) + return nil +} + +// Length of a written frame +func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 + quicvarint.Len(uint64(f.StreamLimit)) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/transport_parameters.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/transport_parameters.go new file mode 100644 index 00000000..e1f83cd6 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/transport_parameters.go @@ -0,0 +1,476 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "net" + "sort" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" +) + +const transportParameterMarshalingVersion = 1 + +func init() { + rand.Seed(time.Now().UTC().UnixNano()) +} + +type transportParameterID uint64 + +const ( + originalDestinationConnectionIDParameterID transportParameterID = 0x0 + maxIdleTimeoutParameterID transportParameterID = 0x1 + statelessResetTokenParameterID transportParameterID = 0x2 + maxUDPPayloadSizeParameterID transportParameterID = 0x3 + initialMaxDataParameterID transportParameterID = 0x4 + initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5 + initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6 + initialMaxStreamDataUniParameterID transportParameterID = 0x7 + initialMaxStreamsBidiParameterID transportParameterID = 0x8 + initialMaxStreamsUniParameterID transportParameterID = 0x9 + ackDelayExponentParameterID transportParameterID = 0xa + maxAckDelayParameterID transportParameterID = 0xb + disableActiveMigrationParameterID transportParameterID = 0xc + preferredAddressParameterID transportParameterID = 0xd + activeConnectionIDLimitParameterID transportParameterID = 0xe + initialSourceConnectionIDParameterID transportParameterID = 0xf + retrySourceConnectionIDParameterID transportParameterID = 0x10 + // RFC 9221 + maxDatagramFrameSizeParameterID transportParameterID = 0x20 +) + +// PreferredAddress is the value encoding in the preferred_address transport parameter +type PreferredAddress struct { + IPv4 net.IP + IPv4Port uint16 + IPv6 net.IP + IPv6Port uint16 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + +// TransportParameters are parameters sent to the peer during the handshake +type TransportParameters struct { + InitialMaxStreamDataBidiLocal protocol.ByteCount + InitialMaxStreamDataBidiRemote protocol.ByteCount + InitialMaxStreamDataUni protocol.ByteCount + InitialMaxData protocol.ByteCount + + MaxAckDelay time.Duration + AckDelayExponent uint8 + + DisableActiveMigration bool + + MaxUDPPayloadSize protocol.ByteCount + + MaxUniStreamNum protocol.StreamNum + MaxBidiStreamNum protocol.StreamNum + + MaxIdleTimeout time.Duration + + PreferredAddress *PreferredAddress + + OriginalDestinationConnectionID protocol.ConnectionID + InitialSourceConnectionID protocol.ConnectionID + RetrySourceConnectionID *protocol.ConnectionID // use a pointer here to distinguish zero-length connection IDs from missing transport parameters + + StatelessResetToken *protocol.StatelessResetToken + ActiveConnectionIDLimit uint64 + + MaxDatagramFrameSize protocol.ByteCount +} + +// Unmarshal the transport parameters +func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { + if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil { + return &qerr.TransportError{ + ErrorCode: qerr.TransportParameterError, + ErrorMessage: err.Error(), + } + } + return nil +} + +func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error { + // needed to check that every parameter is only sent at most once + var parameterIDs []transportParameterID + + var ( + readOriginalDestinationConnectionID bool + readInitialSourceConnectionID bool + ) + + p.AckDelayExponent = protocol.DefaultAckDelayExponent + p.MaxAckDelay = protocol.DefaultMaxAckDelay + p.MaxDatagramFrameSize = protocol.InvalidByteCount + + for r.Len() > 0 { + paramIDInt, err := quicvarint.Read(r) + if err != nil { + return err + } + paramID := transportParameterID(paramIDInt) + paramLen, err := quicvarint.Read(r) + if err != nil { + return err + } + if uint64(r.Len()) < paramLen { + return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen) + } + parameterIDs = append(parameterIDs, paramID) + switch paramID { + case maxIdleTimeoutParameterID, + maxUDPPayloadSizeParameterID, + initialMaxDataParameterID, + initialMaxStreamDataBidiLocalParameterID, + initialMaxStreamDataBidiRemoteParameterID, + initialMaxStreamDataUniParameterID, + initialMaxStreamsBidiParameterID, + initialMaxStreamsUniParameterID, + maxAckDelayParameterID, + activeConnectionIDLimitParameterID, + maxDatagramFrameSizeParameterID, + ackDelayExponentParameterID: + if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { + return err + } + case preferredAddressParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a preferred_address") + } + if err := p.readPreferredAddress(r, int(paramLen)); err != nil { + return err + } + case disableActiveMigrationParameterID: + if paramLen != 0 { + return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen) + } + p.DisableActiveMigration = true + case statelessResetTokenParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a stateless_reset_token") + } + if paramLen != 16 { + return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) + } + var token protocol.StatelessResetToken + r.Read(token[:]) + p.StatelessResetToken = &token + case originalDestinationConnectionIDParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent an original_destination_connection_id") + } + p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + readOriginalDestinationConnectionID = true + case initialSourceConnectionIDParameterID: + p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + readInitialSourceConnectionID = true + case retrySourceConnectionIDParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a retry_source_connection_id") + } + connID, _ := protocol.ReadConnectionID(r, int(paramLen)) + p.RetrySourceConnectionID = &connID + default: + r.Seek(int64(paramLen), io.SeekCurrent) + } + } + + if !fromSessionTicket { + if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { + return errors.New("missing original_destination_connection_id") + } + if p.MaxUDPPayloadSize == 0 { + p.MaxUDPPayloadSize = protocol.MaxByteCount + } + if !readInitialSourceConnectionID { + return errors.New("missing initial_source_connection_id") + } + } + + // check that every transport parameter was sent at most once + sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] }) + for i := 0; i < len(parameterIDs)-1; i++ { + if parameterIDs[i] == parameterIDs[i+1] { + return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i]) + } + } + + return nil +} + +func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { + remainingLen := r.Len() + pa := &PreferredAddress{} + ipv4 := make([]byte, 4) + if _, err := io.ReadFull(r, ipv4); err != nil { + return err + } + pa.IPv4 = net.IP(ipv4) + port, err := utils.BigEndian.ReadUint16(r) + if err != nil { + return err + } + pa.IPv4Port = port + ipv6 := make([]byte, 16) + if _, err := io.ReadFull(r, ipv6); err != nil { + return err + } + pa.IPv6 = net.IP(ipv6) + port, err = utils.BigEndian.ReadUint16(r) + if err != nil { + return err + } + pa.IPv6Port = port + connIDLen, err := r.ReadByte() + if err != nil { + return err + } + if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen { + return fmt.Errorf("invalid connection ID length: %d", connIDLen) + } + connID, err := protocol.ReadConnectionID(r, int(connIDLen)) + if err != nil { + return err + } + pa.ConnectionID = connID + if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil { + return err + } + if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen { + return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead) + } + p.PreferredAddress = pa + return nil +} + +func (p *TransportParameters) readNumericTransportParameter( + r *bytes.Reader, + paramID transportParameterID, + expectedLen int, +) error { + remainingLen := r.Len() + val, err := quicvarint.Read(r) + if err != nil { + return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) + } + if remainingLen-r.Len() != expectedLen { + return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) + } + //nolint:exhaustive // This only covers the numeric transport parameters. + switch paramID { + case initialMaxStreamDataBidiLocalParameterID: + p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val) + case initialMaxStreamDataBidiRemoteParameterID: + p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val) + case initialMaxStreamDataUniParameterID: + p.InitialMaxStreamDataUni = protocol.ByteCount(val) + case initialMaxDataParameterID: + p.InitialMaxData = protocol.ByteCount(val) + case initialMaxStreamsBidiParameterID: + p.MaxBidiStreamNum = protocol.StreamNum(val) + if p.MaxBidiStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_bidi too large: %d (maximum %d)", p.MaxBidiStreamNum, protocol.MaxStreamCount) + } + case initialMaxStreamsUniParameterID: + p.MaxUniStreamNum = protocol.StreamNum(val) + if p.MaxUniStreamNum > protocol.MaxStreamCount { + return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) + } + case maxIdleTimeoutParameterID: + p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) + case maxUDPPayloadSizeParameterID: + if val < 1200 { + return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) + } + p.MaxUDPPayloadSize = protocol.ByteCount(val) + case ackDelayExponentParameterID: + if val > protocol.MaxAckDelayExponent { + return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent) + } + p.AckDelayExponent = uint8(val) + case maxAckDelayParameterID: + if val > uint64(protocol.MaxMaxAckDelay/time.Millisecond) { + return fmt.Errorf("invalid value for max_ack_delay: %dms (maximum %dms)", val, protocol.MaxMaxAckDelay/time.Millisecond) + } + p.MaxAckDelay = time.Duration(val) * time.Millisecond + case activeConnectionIDLimitParameterID: + p.ActiveConnectionIDLimit = val + case maxDatagramFrameSizeParameterID: + p.MaxDatagramFrameSize = protocol.ByteCount(val) + default: + return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID) + } + return nil +} + +// Marshal the transport parameters +func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { + b := &bytes.Buffer{} + + // add a greased value + quicvarint.Write(b, uint64(27+31*rand.Intn(100))) + length := rand.Intn(16) + randomData := make([]byte, length) + rand.Read(randomData) + quicvarint.Write(b, uint64(length)) + b.Write(randomData) + + // initial_max_stream_data_bidi_local + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + // initial_max_stream_data_bidi_remote + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + // initial_max_stream_data_uni + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + // initial_max_data + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + // initial_max_bidi_streams + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + // initial_max_uni_streams + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // idle_timeout + p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) + // max_packet_size + p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) + // max_ack_delay + // Only send it if is different from the default value. + if p.MaxAckDelay != protocol.DefaultMaxAckDelay { + p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) + } + // ack_delay_exponent + // Only send it if is different from the default value. + if p.AckDelayExponent != protocol.DefaultAckDelayExponent { + p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) + } + // disable_active_migration + if p.DisableActiveMigration { + quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) + quicvarint.Write(b, 0) + } + if pers == protocol.PerspectiveServer { + // stateless_reset_token + if p.StatelessResetToken != nil { + quicvarint.Write(b, uint64(statelessResetTokenParameterID)) + quicvarint.Write(b, 16) + b.Write(p.StatelessResetToken[:]) + } + // original_destination_connection_id + quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) + b.Write(p.OriginalDestinationConnectionID.Bytes()) + // preferred_address + if p.PreferredAddress != nil { + quicvarint.Write(b, uint64(preferredAddressParameterID)) + quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) + ipv4 := p.PreferredAddress.IPv4 + b.Write(ipv4[len(ipv4)-4:]) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) + b.Write(p.PreferredAddress.IPv6) + utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) + b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) + b.Write(p.PreferredAddress.ConnectionID.Bytes()) + b.Write(p.PreferredAddress.StatelessResetToken[:]) + } + } + // active_connection_id_limit + p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + // initial_source_connection_id + quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) + b.Write(p.InitialSourceConnectionID.Bytes()) + // retry_source_connection_id + if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { + quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) + quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) + b.Write(p.RetrySourceConnectionID.Bytes()) + } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + } + return b.Bytes() +} + +func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { + quicvarint.Write(b, uint64(id)) + quicvarint.Write(b, uint64(quicvarint.Len(val))) + quicvarint.Write(b, val) +} + +// MarshalForSessionTicket marshals the transport parameters we save in the session ticket. +// When sending a 0-RTT enabled TLS session tickets, we need to save the transport parameters. +// The client will remember the transport parameters used in the last session, +// and apply those to the 0-RTT data it sends. +// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT +// if the transport parameters changed. +// Since the session ticket is encrypted, the serialization format is defined by the server. +// For convenience, we use the same format that we also use for sending the transport parameters. +func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { + quicvarint.Write(b, transportParameterMarshalingVersion) + + // initial_max_stream_data_bidi_local + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + // initial_max_stream_data_bidi_remote + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + // initial_max_stream_data_uni + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + // initial_max_data + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + // initial_max_bidi_streams + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + // initial_max_uni_streams + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // active_connection_id_limit + p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) +} + +// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. +func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error { + version, err := quicvarint.Read(r) + if err != nil { + return err + } + if version != transportParameterMarshalingVersion { + return fmt.Errorf("unknown transport parameter marshaling version: %d", version) + } + return p.unmarshal(r, protocol.PerspectiveServer, true) +} + +// ValidFor0RTT checks if the transport parameters match those saved in the session ticket. +func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { + return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && + p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && + p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && + p.InitialMaxData >= saved.InitialMaxData && + p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && + p.MaxUniStreamNum >= saved.MaxUniStreamNum && + p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit +} + +// String returns a string representation, intended for logging. +func (p *TransportParameters) String() string { + logString := "&wire.TransportParameters{OriginalDestinationConnectionID: %s, InitialSourceConnectionID: %s, " + logParams := []interface{}{p.OriginalDestinationConnectionID, p.InitialSourceConnectionID} + if p.RetrySourceConnectionID != nil { + logString += "RetrySourceConnectionID: %s, " + logParams = append(logParams, p.RetrySourceConnectionID) + } + logString += "InitialMaxStreamDataBidiLocal: %d, InitialMaxStreamDataBidiRemote: %d, InitialMaxStreamDataUni: %d, InitialMaxData: %d, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d" + logParams = append(logParams, []interface{}{p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreamNum, p.MaxUniStreamNum, p.MaxIdleTimeout, p.AckDelayExponent, p.MaxAckDelay, p.ActiveConnectionIDLimit}...) + if p.StatelessResetToken != nil { // the client never sends a stateless reset token + logString += ", StatelessResetToken: %#x" + logParams = append(logParams, *p.StatelessResetToken) + } + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + logString += ", MaxDatagramFrameSize: %d" + logParams = append(logParams, p.MaxDatagramFrameSize) + } + logString += "}" + return fmt.Sprintf(logString, logParams...) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go new file mode 100644 index 00000000..196853e0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go @@ -0,0 +1,54 @@ +package wire + +import ( + "bytes" + "crypto/rand" + "errors" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// ParseVersionNegotiationPacket parses a Version Negotiation packet. +func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.VersionNumber, error) { + hdr, err := parseHeader(b, 0) + if err != nil { + return nil, nil, err + } + if b.Len() == 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has empty version list") + } + if b.Len()%4 != 0 { + //nolint:stylecheck + return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") + } + versions := make([]protocol.VersionNumber, b.Len()/4) + for i := 0; b.Len() > 0; i++ { + v, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return nil, nil, err + } + versions[i] = protocol.VersionNumber(v) + } + return hdr, versions, nil +} + +// ComposeVersionNegotiation composes a Version Negotiation +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { + greasedVersions := protocol.GetGreasedVersions(versions) + expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 + buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) + r := make([]byte, 1) + _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. + buf.WriteByte(r[0] | 0x80) + utils.BigEndian.WriteUint32(buf, 0) // version 0 + buf.WriteByte(uint8(destConnID.Len())) + buf.Write(destConnID) + buf.WriteByte(uint8(srcConnID.Len())) + buf.Write(srcConnID) + for _, v := range greasedVersions { + utils.BigEndian.WriteUint32(buf, uint32(v)) + } + return buf.Bytes() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/frame.go b/vendor/github.com/lucas-clemente/quic-go/logging/frame.go new file mode 100644 index 00000000..75705092 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/logging/frame.go @@ -0,0 +1,66 @@ +package logging + +import "github.com/lucas-clemente/quic-go/internal/wire" + +// A Frame is a QUIC frame +type Frame interface{} + +// The AckRange is used within the AckFrame. +// It is a range of packet numbers that is being acknowledged. +type AckRange = wire.AckRange + +type ( + // An AckFrame is an ACK frame. + AckFrame = wire.AckFrame + // A ConnectionCloseFrame is a CONNECTION_CLOSE frame. + ConnectionCloseFrame = wire.ConnectionCloseFrame + // A DataBlockedFrame is a DATA_BLOCKED frame. + DataBlockedFrame = wire.DataBlockedFrame + // A HandshakeDoneFrame is a HANDSHAKE_DONE frame. + HandshakeDoneFrame = wire.HandshakeDoneFrame + // A MaxDataFrame is a MAX_DATA frame. + MaxDataFrame = wire.MaxDataFrame + // A MaxStreamDataFrame is a MAX_STREAM_DATA frame. + MaxStreamDataFrame = wire.MaxStreamDataFrame + // A MaxStreamsFrame is a MAX_STREAMS_FRAME. + MaxStreamsFrame = wire.MaxStreamsFrame + // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame. + NewConnectionIDFrame = wire.NewConnectionIDFrame + // A NewTokenFrame is a NEW_TOKEN frame. + NewTokenFrame = wire.NewTokenFrame + // A PathChallengeFrame is a PATH_CHALLENGE frame. + PathChallengeFrame = wire.PathChallengeFrame + // A PathResponseFrame is a PATH_RESPONSE frame. + PathResponseFrame = wire.PathResponseFrame + // A PingFrame is a PING frame. + PingFrame = wire.PingFrame + // A ResetStreamFrame is a RESET_STREAM frame. + ResetStreamFrame = wire.ResetStreamFrame + // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame. + RetireConnectionIDFrame = wire.RetireConnectionIDFrame + // A StopSendingFrame is a STOP_SENDING frame. + StopSendingFrame = wire.StopSendingFrame + // A StreamsBlockedFrame is a STREAMS_BLOCKED frame. + StreamsBlockedFrame = wire.StreamsBlockedFrame + // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame. + StreamDataBlockedFrame = wire.StreamDataBlockedFrame +) + +// A CryptoFrame is a CRYPTO frame. +type CryptoFrame struct { + Offset ByteCount + Length ByteCount +} + +// A StreamFrame is a STREAM frame. +type StreamFrame struct { + StreamID StreamID + Offset ByteCount + Length ByteCount + Fin bool +} + +// A DatagramFrame is a DATAGRAM frame. +type DatagramFrame struct { + Length ByteCount +} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/interface.go b/vendor/github.com/lucas-clemente/quic-go/logging/interface.go new file mode 100644 index 00000000..f71d68f7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/logging/interface.go @@ -0,0 +1,134 @@ +// Package logging defines a logging interface for quic-go. +// This package should not be considered stable +package logging + +import ( + "context" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/utils" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type ( + // A ByteCount is used to count bytes. + ByteCount = protocol.ByteCount + // A ConnectionID is a QUIC Connection ID. + ConnectionID = protocol.ConnectionID + // The EncryptionLevel is the encryption level of a packet. + EncryptionLevel = protocol.EncryptionLevel + // The KeyPhase is the key phase of the 1-RTT keys. + KeyPhase = protocol.KeyPhase + // The KeyPhaseBit is the value of the key phase bit of the 1-RTT packets. + KeyPhaseBit = protocol.KeyPhaseBit + // The PacketNumber is the packet number of a packet. + PacketNumber = protocol.PacketNumber + // The Perspective is the role of a QUIC endpoint (client or server). + Perspective = protocol.Perspective + // A StatelessResetToken is a stateless reset token. + StatelessResetToken = protocol.StatelessResetToken + // The StreamID is the stream ID. + StreamID = protocol.StreamID + // The StreamNum is the number of the stream. + StreamNum = protocol.StreamNum + // The StreamType is the type of the stream (unidirectional or bidirectional). + StreamType = protocol.StreamType + // The VersionNumber is the QUIC version. + VersionNumber = protocol.VersionNumber + + // The Header is the QUIC packet header, before removing header protection. + Header = wire.Header + // The ExtendedHeader is the QUIC packet header, after removing header protection. + ExtendedHeader = wire.ExtendedHeader + // The TransportParameters are QUIC transport parameters. + TransportParameters = wire.TransportParameters + // The PreferredAddress is the preferred address sent in the transport parameters. + PreferredAddress = wire.PreferredAddress + + // A TransportError is a transport-level error code. + TransportError = qerr.TransportErrorCode + // An ApplicationError is an application-defined error code. + ApplicationError = qerr.TransportErrorCode + + // The RTTStats contain statistics used by the congestion controller. + RTTStats = utils.RTTStats +) + +const ( + // KeyPhaseZero is key phase bit 0 + KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero + // KeyPhaseOne is key phase bit 1 + KeyPhaseOne KeyPhaseBit = protocol.KeyPhaseOne +) + +const ( + // PerspectiveServer is used for a QUIC server + PerspectiveServer Perspective = protocol.PerspectiveServer + // PerspectiveClient is used for a QUIC client + PerspectiveClient Perspective = protocol.PerspectiveClient +) + +const ( + // EncryptionInitial is the Initial encryption level + EncryptionInitial EncryptionLevel = protocol.EncryptionInitial + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake EncryptionLevel = protocol.EncryptionHandshake + // Encryption1RTT is the 1-RTT encryption level + Encryption1RTT EncryptionLevel = protocol.Encryption1RTT + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT EncryptionLevel = protocol.Encryption0RTT +) + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni = protocol.StreamTypeUni + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi = protocol.StreamTypeBidi +) + +// A Tracer traces events. +type Tracer interface { + // TracerForConnection requests a new tracer for a connection. + // The ODCID is the original destination connection ID: + // The destination connection ID that the client used on the first Initial packet it sent on this connection. + // If nil is returned, tracing will be disabled for this connection. + TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer + + SentPacket(net.Addr, *Header, ByteCount, []Frame) + DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) +} + +// A ConnectionTracer records events. +type ConnectionTracer interface { + StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) + ClosedConnection(error) + SentTransportParameters(*TransportParameters) + ReceivedTransportParameters(*TransportParameters) + RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT + SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) + ReceivedVersionNegotiationPacket(*Header, []VersionNumber) + ReceivedRetry(*Header) + ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) + BufferedPacket(PacketType) + DroppedPacket(PacketType, ByteCount, PacketDropReason) + UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) + AcknowledgedPacket(EncryptionLevel, PacketNumber) + LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState(CongestionState) + UpdatedPTOCount(value uint32) + UpdatedKeyFromTLS(EncryptionLevel, Perspective) + UpdatedKey(generation KeyPhase, remote bool) + DroppedEncryptionLevel(EncryptionLevel) + DroppedKey(generation KeyPhase) + SetLossTimer(TimerType, EncryptionLevel, time.Time) + LossTimerExpired(TimerType, EncryptionLevel) + LossTimerCanceled() + // Close is called when the connection is closed. + Close() + Debug(name, msg string) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go new file mode 100644 index 00000000..c5aa8a16 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go @@ -0,0 +1,4 @@ +package logging + +//go:generate sh -c "mockgen -package logging -self_package github.com/lucas-clemente/quic-go/logging -destination mock_connection_tracer_test.go github.com/lucas-clemente/quic-go/logging ConnectionTracer" +//go:generate sh -c "mockgen -package logging -self_package github.com/lucas-clemente/quic-go/logging -destination mock_tracer_test.go github.com/lucas-clemente/quic-go/logging Tracer" diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/multiplex.go b/vendor/github.com/lucas-clemente/quic-go/logging/multiplex.go new file mode 100644 index 00000000..8280e8cd --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/logging/multiplex.go @@ -0,0 +1,219 @@ +package logging + +import ( + "context" + "net" + "time" +) + +type tracerMultiplexer struct { + tracers []Tracer +} + +var _ Tracer = &tracerMultiplexer{} + +// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. +func NewMultiplexedTracer(tracers ...Tracer) Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &tracerMultiplexer{tracers} +} + +func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { + var connTracers []ConnectionTracer + for _, t := range m.tracers { + if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { + connTracers = append(connTracers, ct) + } + } + return NewMultiplexedConnectionTracer(connTracers...) +} + +func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range m.tracers { + t.SentPacket(remote, hdr, size, frames) + } +} + +func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range m.tracers { + t.DroppedPacket(remote, typ, size, reason) + } +} + +type connTracerMultiplexer struct { + tracers []ConnectionTracer +} + +var _ ConnectionTracer = &connTracerMultiplexer{} + +// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. +func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &connTracerMultiplexer{tracers: tracers} +} + +func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { + for _, t := range m.tracers { + t.StartedConnection(local, remote, srcConnID, destConnID) + } +} + +func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range m.tracers { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } +} + +func (m *connTracerMultiplexer) ClosedConnection(e error) { + for _, t := range m.tracers { + t.ClosedConnection(e) + } +} + +func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.SentTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.ReceivedTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { + for _, t := range m.tracers { + t.RestoredTransportParameters(tp) + } +} + +func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { + for _, t := range m.tracers { + t.SentPacket(hdr, size, ack, frames) + } +} + +func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(hdr *Header, versions []VersionNumber) { + for _, t := range m.tracers { + t.ReceivedVersionNegotiationPacket(hdr, versions) + } +} + +func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { + for _, t := range m.tracers { + t.ReceivedRetry(hdr) + } +} + +func (m *connTracerMultiplexer) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { + for _, t := range m.tracers { + t.ReceivedPacket(hdr, size, frames) + } +} + +func (m *connTracerMultiplexer) BufferedPacket(typ PacketType) { + for _, t := range m.tracers { + t.BufferedPacket(typ) + } +} + +func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range m.tracers { + t.DroppedPacket(typ, size, reason) + } +} + +func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { + for _, t := range m.tracers { + t.UpdatedCongestionState(state) + } +} + +func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { + for _, t := range m.tracers { + t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) + } +} + +func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range m.tracers { + t.AcknowledgedPacket(encLevel, pn) + } +} + +func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range m.tracers { + t.LostPacket(encLevel, pn, reason) + } +} + +func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { + for _, t := range m.tracers { + t.UpdatedPTOCount(value) + } +} + +func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { + for _, t := range m.tracers { + t.UpdatedKeyFromTLS(encLevel, perspective) + } +} + +func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { + for _, t := range m.tracers { + t.UpdatedKey(generation, remote) + } +} + +func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { + for _, t := range m.tracers { + t.DroppedEncryptionLevel(encLevel) + } +} + +func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { + for _, t := range m.tracers { + t.DroppedKey(generation) + } +} + +func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { + for _, t := range m.tracers { + t.SetLossTimer(typ, encLevel, exp) + } +} + +func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { + for _, t := range m.tracers { + t.LossTimerExpired(typ, encLevel) + } +} + +func (m *connTracerMultiplexer) LossTimerCanceled() { + for _, t := range m.tracers { + t.LossTimerCanceled() + } +} + +func (m *connTracerMultiplexer) Debug(name, msg string) { + for _, t := range m.tracers { + t.Debug(name, msg) + } +} + +func (m *connTracerMultiplexer) Close() { + for _, t := range m.tracers { + t.Close() + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/packet_header.go b/vendor/github.com/lucas-clemente/quic-go/logging/packet_header.go new file mode 100644 index 00000000..ea4282fe --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/logging/packet_header.go @@ -0,0 +1,27 @@ +package logging + +import ( + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// PacketTypeFromHeader determines the packet type from a *wire.Header. +func PacketTypeFromHeader(hdr *Header) PacketType { + if !hdr.IsLongHeader { + return PacketType1RTT + } + if hdr.Version == 0 { + return PacketTypeVersionNegotiation + } + switch hdr.Type { + case protocol.PacketTypeInitial: + return PacketTypeInitial + case protocol.PacketTypeHandshake: + return PacketTypeHandshake + case protocol.PacketType0RTT: + return PacketType0RTT + case protocol.PacketTypeRetry: + return PacketTypeRetry + default: + return PacketTypeNotDetermined + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/types.go b/vendor/github.com/lucas-clemente/quic-go/logging/types.go new file mode 100644 index 00000000..ad800692 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/logging/types.go @@ -0,0 +1,94 @@ +package logging + +// PacketType is the packet type of a QUIC packet +type PacketType uint8 + +const ( + // PacketTypeInitial is the packet type of an Initial packet + PacketTypeInitial PacketType = iota + // PacketTypeHandshake is the packet type of a Handshake packet + PacketTypeHandshake + // PacketTypeRetry is the packet type of a Retry packet + PacketTypeRetry + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT + // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet + PacketTypeVersionNegotiation + // PacketType1RTT is a 1-RTT packet + PacketType1RTT + // PacketTypeStatelessReset is a stateless reset + PacketTypeStatelessReset + // PacketTypeNotDetermined is the packet type when it could not be determined + PacketTypeNotDetermined +) + +type PacketLossReason uint8 + +const ( + // PacketLossReorderingThreshold: when a packet is deemed lost due to reordering threshold + PacketLossReorderingThreshold PacketLossReason = iota + // PacketLossTimeThreshold: when a packet is deemed lost due to time threshold + PacketLossTimeThreshold +) + +type PacketDropReason uint8 + +const ( + // PacketDropKeyUnavailable is used when a packet is dropped because keys are unavailable + PacketDropKeyUnavailable PacketDropReason = iota + // PacketDropUnknownConnectionID is used when a packet is dropped because the connection ID is unknown + PacketDropUnknownConnectionID + // PacketDropHeaderParseError is used when a packet is dropped because header parsing failed + PacketDropHeaderParseError + // PacketDropPayloadDecryptError is used when a packet is dropped because decrypting the payload failed + PacketDropPayloadDecryptError + // PacketDropProtocolViolation is used when a packet is dropped due to a protocol violation + PacketDropProtocolViolation + // PacketDropDOSPrevention is used when a packet is dropped to mitigate a DoS attack + PacketDropDOSPrevention + // PacketDropUnsupportedVersion is used when a packet is dropped because the version is not supported + PacketDropUnsupportedVersion + // PacketDropUnexpectedPacket is used when an unexpected packet is received + PacketDropUnexpectedPacket + // PacketDropUnexpectedSourceConnectionID is used when a packet with an unexpected source connection ID is received + PacketDropUnexpectedSourceConnectionID + // PacketDropUnexpectedVersion is used when a packet with an unexpected version is received + PacketDropUnexpectedVersion + // PacketDropDuplicate is used when a duplicate packet is received + PacketDropDuplicate +) + +// TimerType is the type of the loss detection timer +type TimerType uint8 + +const ( + // TimerTypeACK is the timer type for the early retransmit timer + TimerTypeACK TimerType = iota + // TimerTypePTO is the timer type for the PTO retransmit timer + TimerTypePTO +) + +// TimeoutReason is the reason why a connection is closed +type TimeoutReason uint8 + +const ( + // TimeoutReasonHandshake is used when the connection is closed due to a handshake timeout + // This reason is not defined in the qlog draft, but very useful for debugging. + TimeoutReasonHandshake TimeoutReason = iota + // TimeoutReasonIdle is used when the connection is closed due to an idle timeout + // This reason is not defined in the qlog draft, but very useful for debugging. + TimeoutReasonIdle +) + +type CongestionState uint8 + +const ( + // CongestionStateSlowStart is the slow start phase of Reno / Cubic + CongestionStateSlowStart CongestionState = iota + // CongestionStateCongestionAvoidance is the slow start phase of Reno / Cubic + CongestionStateCongestionAvoidance + // CongestionStateRecovery is the recovery phase of Reno / Cubic + CongestionStateRecovery + // CongestionStateApplicationLimited means that the congestion controller is application limited + CongestionStateApplicationLimited +) diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/mockgen.go new file mode 100644 index 00000000..22c2c0e7 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen.go @@ -0,0 +1,27 @@ +package quic + +//go:generate sh -c "./mockgen_private.sh quic mock_send_conn_test.go github.com/lucas-clemente/quic-go sendConn" +//go:generate sh -c "./mockgen_private.sh quic mock_sender_test.go github.com/lucas-clemente/quic-go sender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" +//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter" +//go:generate sh -c "./mockgen_private.sh quic mock_crypto_data_handler_test.go github.com/lucas-clemente/quic-go cryptoDataHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_frame_source_test.go github.com/lucas-clemente/quic-go frameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/lucas-clemente/quic-go ackFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager" +//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager" +//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" +//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" +//go:generate sh -c "./mockgen_private.sh quic mock_mtu_discoverer_test.go github.com/lucas-clemente/quic-go mtuDiscoverer" +//go:generate sh -c "./mockgen_private.sh quic mock_conn_runner_test.go github.com/lucas-clemente/quic-go connRunner" +//go:generate sh -c "./mockgen_private.sh quic mock_quic_conn_test.go github.com/lucas-clemente/quic-go quicConn" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager" +//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer" +//go:generate sh -c "./mockgen_private.sh quic mock_batch_conn_test.go github.com/lucas-clemente/quic-go batchConn" +//go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_token_store_test.go github.com/lucas-clemente/quic-go TokenStore" +//go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh b/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh new file mode 100644 index 00000000..92829d77 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +DEST=$2 +PACKAGE=$3 +TMPFILE="mockgen_tmp.go" +# uppercase the name of the interface +ORIG_INTERFACE_NAME=$4 +INTERFACE_NAME="$(tr '[:lower:]' '[:upper:]' <<< ${ORIG_INTERFACE_NAME:0:1})${ORIG_INTERFACE_NAME:1}" + +# Gather all files that contain interface definitions. +# These interfaces might be used as embedded interfaces, +# so we need to pass them to mockgen as aux_files. +AUX=() +for f in *.go; do + if [[ -z ${f##*_test.go} ]]; then + # skip test files + continue; + fi + if $(egrep -qe "type (.*) interface" $f); then + AUX+=("github.com/lucas-clemente/quic-go=$f") + fi +done + +# Find the file that defines the interface we're mocking. +for f in *.go; do + if [[ -z ${f##*_test.go} ]]; then + # skip test files + continue; + fi + INTERFACE=$(sed -n "/^type $ORIG_INTERFACE_NAME interface/,/^}/p" $f) + if [[ -n "$INTERFACE" ]]; then + SRC=$f + break + fi +done + +if [[ -z "$INTERFACE" ]]; then + echo "Interface $ORIG_INTERFACE_NAME not found." + exit 1 +fi + +AUX_FILES=$(IFS=, ; echo "${AUX[*]}") + +## create a public alias for the interface, so that mockgen can process it +echo -e "package $1\n" > $TMPFILE +echo "$INTERFACE" | sed "s/$ORIG_INTERFACE_NAME/$INTERFACE_NAME/" >> $TMPFILE +mockgen -package $1 -self_package $3 -destination $DEST -source=$TMPFILE -aux_files $AUX_FILES +sed "s/$TMPFILE/$SRC/" "$DEST" > "$DEST.new" && mv "$DEST.new" "$DEST" +rm "$TMPFILE" diff --git a/vendor/github.com/lucas-clemente/quic-go/mtu_discoverer.go b/vendor/github.com/lucas-clemente/quic-go/mtu_discoverer.go new file mode 100644 index 00000000..bf38eaac --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/mtu_discoverer.go @@ -0,0 +1,74 @@ +package quic + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type mtuDiscoverer interface { + ShouldSendProbe(now time.Time) bool + GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) +} + +const ( + // At some point, we have to stop searching for a higher MTU. + // We're happy to send a packet that's 10 bytes smaller than the actual MTU. + maxMTUDiff = 20 + // send a probe packet every mtuProbeDelay RTTs + mtuProbeDelay = 5 +) + +type mtuFinder struct { + lastProbeTime time.Time + probeInFlight bool + mtuIncreased func(protocol.ByteCount) + + rttStats *utils.RTTStats + current protocol.ByteCount + max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) +} + +var _ mtuDiscoverer = &mtuFinder{} + +func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) mtuDiscoverer { + return &mtuFinder{ + current: start, + rttStats: rttStats, + lastProbeTime: time.Now(), // to make sure the first probe packet is not sent immediately + mtuIncreased: mtuIncreased, + max: max, + } +} + +func (f *mtuFinder) done() bool { + return f.max-f.current <= maxMTUDiff+1 +} + +func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { + if f.probeInFlight || f.done() { + return false + } + return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) +} + +func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { + size := (f.max + f.current) / 2 + f.lastProbeTime = time.Now() + f.probeInFlight = true + return ackhandler.Frame{ + Frame: &wire.PingFrame{}, + OnLost: func(wire.Frame) { + f.probeInFlight = false + f.max = size + }, + OnAcked: func(wire.Frame) { + f.probeInFlight = false + f.current = size + f.mtuIncreased(size) + }, + }, size +} diff --git a/vendor/github.com/lucas-clemente/quic-go/multiplexer.go b/vendor/github.com/lucas-clemente/quic-go/multiplexer.go new file mode 100644 index 00000000..2271b551 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/multiplexer.go @@ -0,0 +1,107 @@ +package quic + +import ( + "bytes" + "fmt" + "net" + "sync" + + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/logging" +) + +var ( + connMuxerOnce sync.Once + connMuxer multiplexer +) + +type indexableConn interface { + LocalAddr() net.Addr +} + +type multiplexer interface { + AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) + RemoveConn(indexableConn) error +} + +type connManager struct { + connIDLen int + statelessResetKey []byte + tracer logging.Tracer + manager packetHandlerManager +} + +// The connMultiplexer listens on multiple net.PacketConns and dispatches +// incoming packets to the connection handler. +type connMultiplexer struct { + mutex sync.Mutex + + conns map[string] /* LocalAddr().String() */ connManager + newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests + + logger utils.Logger +} + +var _ multiplexer = &connMultiplexer{} + +func getMultiplexer() multiplexer { + connMuxerOnce.Do(func() { + connMuxer = &connMultiplexer{ + conns: make(map[string]connManager), + logger: utils.DefaultLogger.WithPrefix("muxer"), + newPacketHandlerManager: newPacketHandlerMap, + } + }) + return connMuxer +} + +func (m *connMultiplexer) AddConn( + c net.PacketConn, + connIDLen int, + statelessResetKey []byte, + tracer logging.Tracer, +) (packetHandlerManager, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + addr := c.LocalAddr() + connIndex := addr.Network() + " " + addr.String() + p, ok := m.conns[connIndex] + if !ok { + manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) + if err != nil { + return nil, err + } + p = connManager{ + connIDLen: connIDLen, + statelessResetKey: statelessResetKey, + manager: manager, + tracer: tracer, + } + m.conns[connIndex] = p + } else { + if p.connIDLen != connIDLen { + return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) + } + if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { + return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") + } + if tracer != p.tracer { + return nil, fmt.Errorf("cannot use different tracers on the same packet conn") + } + } + return p.manager, nil +} + +func (m *connMultiplexer) RemoveConn(c indexableConn) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() + if _, ok := m.conns[connIndex]; !ok { + return fmt.Errorf("cannote remove connection, connection is unknown") + } + + delete(m.conns, connIndex) + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go b/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go new file mode 100644 index 00000000..2d55a95e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go @@ -0,0 +1,489 @@ +package quic + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "hash" + "io" + "log" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" +) + +type zeroRTTQueue struct { + queue []*receivedPacket + retireTimer *time.Timer +} + +var _ packetHandler = &zeroRTTQueue{} + +func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { + if len(h.queue) < protocol.Max0RTTQueueLen { + h.queue = append(h.queue, p) + } +} +func (h *zeroRTTQueue) shutdown() {} +func (h *zeroRTTQueue) destroy(error) {} +func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } +func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { + for _, p := range h.queue { + sess.handlePacket(p) + } +} + +func (h *zeroRTTQueue) Clear() { + for _, p := range h.queue { + p.buffer.Release() + } +} + +// rawConn is a connection that allow reading of a receivedPacket. +type rawConn interface { + ReadPacket() (*receivedPacket, error) + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + LocalAddr() net.Addr + io.Closer +} + +type packetHandlerMapEntry struct { + packetHandler packetHandler + is0RTTQueue bool +} + +// The packetHandlerMap stores packetHandlers, identified by connection ID. +// It is used: +// * by the server to store connections +// * when multiplexing outgoing connections to store clients +type packetHandlerMap struct { + mutex sync.Mutex + + conn rawConn + connIDLen int + + handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry + resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler + server unknownPacketHandler + numZeroRTTEntries int + + listening chan struct{} // is closed when listen returns + closed bool + + deleteRetiredConnsAfter time.Duration + zeroRTTQueueDuration time.Duration + + statelessResetEnabled bool + statelessResetMutex sync.Mutex + statelessResetHasher hash.Hash + + tracer logging.Tracer + logger utils.Logger +} + +var _ packetHandlerManager = &packetHandlerMap{} + +func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { + conn, ok := c.(interface{ SetReadBuffer(int) error }) + if !ok { + return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") + } + size, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if size >= protocol.DesiredReceiveBufferSize { + logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) + return nil + } + if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { + return fmt.Errorf("failed to increase receive buffer size: %w", err) + } + newSize, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if newSize == size { + return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + if newSize < protocol.DesiredReceiveBufferSize { + return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) + return nil +} + +// only print warnings about the UDP receive buffer size once +var receiveBufferWarningOnce sync.Once + +func newPacketHandlerMap( + c net.PacketConn, + connIDLen int, + statelessResetKey []byte, + tracer logging.Tracer, + logger utils.Logger, +) (packetHandlerManager, error) { + if err := setReceiveBuffer(c, logger); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + receiveBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/lucas-clemente/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) + }) + } + } + conn, err := wrapConn(c) + if err != nil { + return nil, err + } + m := &packetHandlerMap{ + conn: conn, + connIDLen: connIDLen, + listening: make(chan struct{}), + handlers: make(map[string]packetHandlerMapEntry), + resetTokens: make(map[protocol.StatelessResetToken]packetHandler), + deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, + zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, + statelessResetEnabled: len(statelessResetKey) > 0, + statelessResetHasher: hmac.New(sha256.New, statelessResetKey), + tracer: tracer, + logger: logger, + } + go m.listen() + + if logger.Debug() { + go m.logUsage() + } + return m, nil +} + +func (h *packetHandlerMap) logUsage() { + ticker := time.NewTicker(2 * time.Second) + var printedZero bool + for { + select { + case <-h.listening: + return + case <-ticker.C: + } + + h.mutex.Lock() + numHandlers := len(h.handlers) + numTokens := len(h.resetTokens) + h.mutex.Unlock() + // If the number tracked handlers and tokens is zero, only print it a single time. + hasZero := numHandlers == 0 && numTokens == 0 + if !hasZero || (hasZero && !printedZero) { + h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens) + printedZero = false + if hasZero { + printedZero = true + } + } + } +} + +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { + h.mutex.Lock() + defer h.mutex.Unlock() + + if _, ok := h.handlers[string(id)]; ok { + h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) + return false + } + h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} + h.logger.Debugf("Adding connection ID %s.", id) + return true +} + +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { + h.mutex.Lock() + defer h.mutex.Unlock() + + var q *zeroRTTQueue + if entry, ok := h.handlers[string(clientDestConnID)]; ok { + if !entry.is0RTTQueue { + h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) + return false + } + q = entry.packetHandler.(*zeroRTTQueue) + q.retireTimer.Stop() + h.numZeroRTTEntries-- + if h.numZeroRTTEntries < 0 { + panic("number of 0-RTT queues < 0") + } + } + sess := fn() + if q != nil { + q.EnqueueAll(sess) + } + h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} + h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} + h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) + return true +} + +func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.mutex.Lock() + delete(h.handlers, string(id)) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s.", id) +} + +func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { + h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) + time.AfterFunc(h.deleteRetiredConnsAfter, func() { + h.mutex.Lock() + delete(h.handlers, string(id)) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s after it has been retired.", id) + }) +} + +func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { + h.mutex.Lock() + h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} + h.mutex.Unlock() + h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id) + + time.AfterFunc(h.deleteRetiredConnsAfter, func() { + h.mutex.Lock() + handler.shutdown() + delete(h.handlers, string(id)) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id) + }) +} + +func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { + h.mutex.Lock() + h.resetTokens[token] = handler + h.mutex.Unlock() +} + +func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { + h.mutex.Lock() + delete(h.resetTokens, token) + h.mutex.Unlock() +} + +func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { + h.mutex.Lock() + h.server = s + h.mutex.Unlock() +} + +func (h *packetHandlerMap) CloseServer() { + h.mutex.Lock() + if h.server == nil { + h.mutex.Unlock() + return + } + h.server = nil + var wg sync.WaitGroup + for _, entry := range h.handlers { + if entry.packetHandler.getPerspective() == protocol.PerspectiveServer { + wg.Add(1) + go func(handler packetHandler) { + // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + handler.shutdown() + wg.Done() + }(entry.packetHandler) + } + } + h.mutex.Unlock() + wg.Wait() +} + +// Destroy closes the underlying connection and waits until listen() has returned. +// It does not close active connections. +func (h *packetHandlerMap) Destroy() error { + if err := h.conn.Close(); err != nil { + return err + } + <-h.listening // wait until listening returns + return nil +} + +func (h *packetHandlerMap) close(e error) error { + h.mutex.Lock() + if h.closed { + h.mutex.Unlock() + return nil + } + + var wg sync.WaitGroup + for _, entry := range h.handlers { + wg.Add(1) + go func(handler packetHandler) { + handler.destroy(e) + wg.Done() + }(entry.packetHandler) + } + + if h.server != nil { + h.server.setCloseError(e) + } + h.closed = true + h.mutex.Unlock() + wg.Wait() + return getMultiplexer().RemoveConn(h.conn) +} + +func (h *packetHandlerMap) listen() { + defer close(h.listening) + for { + p, err := h.conn.ReadPacket() + //nolint:staticcheck // SA1019 ignore this! + // TODO: This code is used to ignore wsa errors on Windows. + // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. + // See https://github.com/lucas-clemente/quic-go/issues/1737 for details. + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + h.logger.Debugf("Temporary error reading from conn: %w", err) + continue + } + if err != nil { + h.close(err) + return + } + h.handlePacket(p) + } +} + +func (h *packetHandlerMap) handlePacket(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, h.connIDLen) + if err != nil { + h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) + if h.tracer != nil { + h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + p.buffer.MaybeRelease() + return + } + + h.mutex.Lock() + defer h.mutex.Unlock() + + if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { + return + } + + if entry, ok := h.handlers[string(connID)]; ok { + if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue + if wire.Is0RTTPacket(p.data) { + entry.packetHandler.handlePacket(p) + return + } + } else { // existing connection + entry.packetHandler.handlePacket(p) + return + } + } + if p.data[0]&0x80 == 0 { + go h.maybeSendStatelessReset(p, connID) + return + } + if h.server == nil { // no server set + h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + if wire.Is0RTTPacket(p.data) { + if h.numZeroRTTEntries >= protocol.Max0RTTQueues { + return + } + h.numZeroRTTEntries++ + queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} + h.handlers[string(connID)] = packetHandlerMapEntry{ + packetHandler: queue, + is0RTTQueue: true, + } + queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { + h.mutex.Lock() + defer h.mutex.Unlock() + // The entry might have been replaced by an actual connection. + // Only delete it if it's still a 0-RTT queue. + if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { + delete(h.handlers, string(connID)) + h.numZeroRTTEntries-- + if h.numZeroRTTEntries < 0 { + panic("number of 0-RTT queues < 0") + } + entry.packetHandler.(*zeroRTTQueue).Clear() + if h.logger.Debug() { + h.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } + } + }) + queue.handlePacket(p) + return + } + h.server.handlePacket(p) +} + +func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if data[0]&0x80 != 0 { + return false + } + if len(data) < 17 /* type byte + 16 bytes for the reset token */ { + return false + } + + var token protocol.StatelessResetToken + copy(token[:], data[len(data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) + go sess.destroy(&StatelessResetError{Token: token}) + return true + } + return false +} + +func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { + var token protocol.StatelessResetToken + if !h.statelessResetEnabled { + // Return a random stateless reset token. + // This token will be sent in the server's transport parameters. + // By using a random token, an off-path attacker won't be able to disrupt the connection. + rand.Read(token[:]) + return token + } + h.statelessResetMutex.Lock() + h.statelessResetHasher.Write(connID.Bytes()) + copy(token[:], h.statelessResetHasher.Sum(nil)) + h.statelessResetHasher.Reset() + h.statelessResetMutex.Unlock() + return token +} + +func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { + defer p.buffer.Release() + if !h.statelessResetEnabled { + return + } + // Don't send a stateless reset in response to very small packets. + // This includes packets that could be stateless resets. + if len(p.data) <= protocol.MinStatelessResetSize { + return + } + token := h.GetStatelessResetToken(connID) + h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + h.logger.Debugf("Error sending Stateless Reset: %s", err) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go new file mode 100644 index 00000000..1d037ab2 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go @@ -0,0 +1,894 @@ +package quic + +import ( + "bytes" + "errors" + "fmt" + "net" + "time" + + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type packer interface { + PackCoalescedPacket() (*coalescedPacket, error) + PackPacket() (*packedPacket, error) + MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) + MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) + PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error) + + SetMaxPacketSize(protocol.ByteCount) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) + + HandleTransportParameters(*wire.TransportParameters) + SetToken([]byte) +} + +type sealer interface { + handshake.LongHeaderSealer +} + +type payload struct { + frames []ackhandler.Frame + ack *wire.AckFrame + length protocol.ByteCount +} + +type packedPacket struct { + buffer *packetBuffer + *packetContents +} + +type packetContents struct { + header *wire.ExtendedHeader + ack *wire.AckFrame + frames []ackhandler.Frame + + length protocol.ByteCount + + isMTUProbePacket bool +} + +type coalescedPacket struct { + buffer *packetBuffer + packets []*packetContents +} + +func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel { + if !p.header.IsLongHeader { + return protocol.Encryption1RTT + } + //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). + switch p.header.Type { + case protocol.PacketTypeInitial: + return protocol.EncryptionInitial + case protocol.PacketTypeHandshake: + return protocol.EncryptionHandshake + case protocol.PacketType0RTT: + return protocol.Encryption0RTT + default: + panic("can't determine encryption level") + } +} + +func (p *packetContents) IsAckEliciting() bool { + return ackhandler.HasAckElicitingFrames(p.frames) +} + +func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { + largestAcked := protocol.InvalidPacketNumber + if p.ack != nil { + largestAcked = p.ack.LargestAcked() + } + encLevel := p.EncryptionLevel() + for i := range p.frames { + if p.frames[i].OnLost != nil { + continue + } + switch encLevel { + case protocol.EncryptionInitial: + p.frames[i].OnLost = q.AddInitial + case protocol.EncryptionHandshake: + p.frames[i].OnLost = q.AddHandshake + case protocol.Encryption0RTT, protocol.Encryption1RTT: + p.frames[i].OnLost = q.AddAppData + } + } + return &ackhandler.Packet{ + PacketNumber: p.header.PacketNumber, + LargestAcked: largestAcked, + Frames: p.frames, + Length: p.length, + EncryptionLevel: encLevel, + SendTime: now, + IsPathMTUProbePacket: p.isMTUProbePacket, + } +} + +func getMaxPacketSize(addr net.Addr) protocol.ByteCount { + maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + if utils.IsIPv4(udpAddr.IP) { + maxSize = protocol.InitialPacketSizeIPv4 + } else { + maxSize = protocol.InitialPacketSizeIPv6 + } + } + return maxSize +} + +type packetNumberManager interface { + PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) + PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber +} + +type sealingManager interface { + GetInitialSealer() (handshake.LongHeaderSealer, error) + GetHandshakeSealer() (handshake.LongHeaderSealer, error) + Get0RTTSealer() (handshake.LongHeaderSealer, error) + Get1RTTSealer() (handshake.ShortHeaderSealer, error) +} + +type frameSource interface { + HasData() bool + AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) +} + +type ackFrameSource interface { + GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame +} + +type packetPacker struct { + srcConnID protocol.ConnectionID + getDestConnID func() protocol.ConnectionID + + perspective protocol.Perspective + version protocol.VersionNumber + cryptoSetup sealingManager + + initialStream cryptoStream + handshakeStream cryptoStream + + token []byte + + pnManager packetNumberManager + framer frameSource + acks ackFrameSource + datagramQueue *datagramQueue + retransmissionQueue *retransmissionQueue + + maxPacketSize protocol.ByteCount + numNonAckElicitingAcks int +} + +var _ packer = &packetPacker{} + +func newPacketPacker( + srcConnID protocol.ConnectionID, + getDestConnID func() protocol.ConnectionID, + initialStream cryptoStream, + handshakeStream cryptoStream, + packetNumberManager packetNumberManager, + retransmissionQueue *retransmissionQueue, + remoteAddr net.Addr, // only used for determining the max packet size + cryptoSetup sealingManager, + framer frameSource, + acks ackFrameSource, + datagramQueue *datagramQueue, + perspective protocol.Perspective, + version protocol.VersionNumber, +) *packetPacker { + return &packetPacker{ + cryptoSetup: cryptoSetup, + getDestConnID: getDestConnID, + srcConnID: srcConnID, + initialStream: initialStream, + handshakeStream: handshakeStream, + retransmissionQueue: retransmissionQueue, + datagramQueue: datagramQueue, + perspective: perspective, + version: version, + framer: framer, + acks: acks, + pnManager: packetNumberManager, + maxPacketSize: getMaxPacketSize(remoteAddr), + } +} + +// PackConnectionClose packs a packet that closes the connection with a transport error. +func (p *packetPacker) PackConnectionClose(e *qerr.TransportError) (*coalescedPacket, error) { + var reason string + // don't send details of crypto errors + if !e.ErrorCode.IsCryptoError() { + reason = e.ErrorMessage + } + return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason) +} + +// PackApplicationClose packs a packet that closes the connection with an application error. +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError) (*coalescedPacket, error) { + return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage) +} + +func (p *packetPacker) packConnectionClose( + isApplicationError bool, + errorCode uint64, + frameType uint64, + reason string, +) (*coalescedPacket, error) { + var sealers [4]sealer + var hdrs [4]*wire.ExtendedHeader + var payloads [4]*payload + var size protocol.ByteCount + var numPackets uint8 + encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} + for i, encLevel := range encLevels { + if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { + continue + } + ccf := &wire.ConnectionCloseFrame{ + IsApplicationError: isApplicationError, + ErrorCode: errorCode, + FrameType: frameType, + ReasonPhrase: reason, + } + // don't send application errors in Initial or Handshake packets + if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) { + ccf.IsApplicationError = false + ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) + ccf.ReasonPhrase = "" + } + payload := &payload{ + frames: []ackhandler.Frame{{Frame: ccf}}, + length: ccf.Length(p.version), + } + + var sealer sealer + var err error + var keyPhase protocol.KeyPhaseBit // only set for 1-RTT + switch encLevel { + case protocol.EncryptionInitial: + sealer, err = p.cryptoSetup.GetInitialSealer() + case protocol.EncryptionHandshake: + sealer, err = p.cryptoSetup.GetHandshakeSealer() + case protocol.Encryption0RTT: + sealer, err = p.cryptoSetup.Get0RTTSealer() + case protocol.Encryption1RTT: + var s handshake.ShortHeaderSealer + s, err = p.cryptoSetup.Get1RTTSealer() + if err == nil { + keyPhase = s.KeyPhase() + } + sealer = s + } + if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped { + continue + } + if err != nil { + return nil, err + } + sealers[i] = sealer + var hdr *wire.ExtendedHeader + if encLevel == protocol.Encryption1RTT { + hdr = p.getShortHeader(keyPhase) + } else { + hdr = p.getLongHeader(encLevel) + } + hdrs[i] = hdr + payloads[i] = payload + size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) + numPackets++ + } + contents := make([]*packetContents, 0, numPackets) + buffer := getPacketBuffer() + for i, encLevel := range encLevels { + if sealers[i] == nil { + continue + } + var paddingLen protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + paddingLen = p.initialPaddingLen(payloads[i].frames, size) + } + c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false) + if err != nil { + return nil, err + } + contents = append(contents, c) + } + return &coalescedPacket{buffer: buffer, packets: contents}, nil +} + +// packetLength calculates the length of the serialized packet. +// It takes into account that packets that have a tiny payload need to be padded, +// such that len(payload) + packet number len >= 4 + AEAD overhead +func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(hdr.PacketNumberLen) + if payload.length < 4-pnLen { + paddingLen = 4 - pnLen - payload.length + } + return hdr.GetLength(p.version) + payload.length + paddingLen +} + +func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { + var encLevel protocol.EncryptionLevel + var ack *wire.AckFrame + if !handshakeConfirmed { + ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) + if ack != nil { + encLevel = protocol.EncryptionInitial + } else { + ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) + if ack != nil { + encLevel = protocol.EncryptionHandshake + } + } + } + if ack == nil { + ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true) + if ack == nil { + return nil, nil + } + encLevel = protocol.Encryption1RTT + } + payload := &payload{ + ack: ack, + length: ack.Length(p.version), + } + + sealer, hdr, err := p.getSealerAndHeader(encLevel) + if err != nil { + return nil, err + } + return p.writeSinglePacket(hdr, payload, encLevel, sealer) +} + +// size is the expected size of the packet, if no padding was applied. +func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { + // For the server, only ack-eliciting Initial packets need to be padded. + if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { + return 0 + } + if size >= p.maxPacketSize { + return 0 + } + return p.maxPacketSize - size +} + +// PackCoalescedPacket packs a new packet. +// It packs an Initial / Handshake if there is data to send in these packet number spaces. +// It should only be called before the handshake is confirmed. +func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { + maxPacketSize := p.maxPacketSize + if p.perspective == protocol.PerspectiveClient { + maxPacketSize = protocol.MinInitialPacketSize + } + var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader + var initialPayload, handshakePayload, appDataPayload *payload + var numPackets int + // Try packing an Initial packet. + initialSealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil && err != handshake.ErrKeysDropped { + return nil, err + } + var size protocol.ByteCount + if initialSealer != nil { + initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial) + if initialPayload != nil { + size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) + numPackets++ + } + } + + // Add a Handshake packet. + var handshakeSealer sealer + if size < maxPacketSize-protocol.MinCoalescedPacketSize { + var err error + handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if handshakeSealer != nil { + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake) + if handshakePayload != nil { + s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) + size += s + numPackets++ + } + } + } + + // Add a 0-RTT / 1-RTT packet. + var appDataSealer sealer + appDataEncLevel := protocol.Encryption1RTT + if size < maxPacketSize-protocol.MinCoalescedPacketSize { + var err error + appDataSealer, appDataHdr, appDataPayload = p.maybeGetAppDataPacket(maxPacketSize-size, size) + if err != nil { + return nil, err + } + if appDataHdr != nil { + if appDataHdr.IsLongHeader { + appDataEncLevel = protocol.Encryption0RTT + } + if appDataPayload != nil { + size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) + numPackets++ + } + } + } + + if numPackets == 0 { + return nil, nil + } + + buffer := getPacketBuffer() + packet := &coalescedPacket{ + buffer: buffer, + packets: make([]*packetContents, 0, numPackets), + } + if initialPayload != nil { + padding := p.initialPaddingLen(initialPayload.frames, size) + cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false) + if err != nil { + return nil, err + } + packet.packets = append(packet.packets, cont) + } + if handshakePayload != nil { + cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false) + if err != nil { + return nil, err + } + packet.packets = append(packet.packets, cont) + } + if appDataPayload != nil { + cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false) + if err != nil { + return nil, err + } + packet.packets = append(packet.packets, cont) + } + return packet, nil +} + +// PackPacket packs a packet in the application data packet number space. +// It should be called after the handshake is confirmed. +func (p *packetPacker) PackPacket() (*packedPacket, error) { + sealer, hdr, payload := p.maybeGetAppDataPacket(p.maxPacketSize, 0) + if payload == nil { + return nil, nil + } + buffer := getPacketBuffer() + encLevel := protocol.Encryption1RTT + if hdr.IsLongHeader { + encLevel = protocol.Encryption0RTT + } + cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer, false) + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: cont, + }, nil +} + +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { + var s cryptoStream + var hasRetransmission bool + //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. + switch encLevel { + case protocol.EncryptionInitial: + s = p.initialStream + hasRetransmission = p.retransmissionQueue.HasInitialData() + case protocol.EncryptionHandshake: + s = p.handshakeStream + hasRetransmission = p.retransmissionQueue.HasHandshakeData() + } + + hasData := s.HasData() + var ack *wire.AckFrame + if encLevel == protocol.EncryptionInitial || currentSize == 0 { + ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) + } + if !hasData && !hasRetransmission && ack == nil { + // nothing to send + return nil, nil + } + + var payload payload + if ack != nil { + payload.ack = ack + payload.length = ack.Length(p.version) + maxPacketSize -= payload.length + } + hdr := p.getLongHeader(encLevel) + maxPacketSize -= hdr.GetLength(p.version) + if hasRetransmission { + for { + var f wire.Frame + //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s + switch encLevel { + case protocol.EncryptionInitial: + f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) + case protocol.EncryptionHandshake: + f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) + } + if f == nil { + break + } + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) + frameLen := f.Length(p.version) + payload.length += frameLen + maxPacketSize -= frameLen + } + } else if s.HasData() { + cf := s.PopCryptoFrame(maxPacketSize) + payload.frames = []ackhandler.Frame{{Frame: cf}} + payload.length += cf.Length(p.version) + } + return hdr, &payload +} + +func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol.ByteCount) (sealer, *wire.ExtendedHeader, *payload) { + var sealer sealer + var encLevel protocol.EncryptionLevel + var hdr *wire.ExtendedHeader + oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() + if err == nil { + encLevel = protocol.Encryption1RTT + sealer = oneRTTSealer + hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) + } else { + // 1-RTT sealer not yet available + if p.perspective != protocol.PerspectiveClient { + return nil, nil, nil + } + sealer, err = p.cryptoSetup.Get0RTTSealer() + if sealer == nil || err != nil { + return nil, nil, nil + } + encLevel = protocol.Encryption0RTT + hdr = p.getLongHeader(protocol.Encryption0RTT) + } + + maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) + payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) + return sealer, hdr, payload +} + +func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { + payload := p.composeNextPacket(maxPayloadSize, ackAllowed) + + // check if we have anything to send + if len(payload.frames) == 0 { + if payload.ack == nil { + return nil + } + // the packet only contains an ACK + if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { + ping := &wire.PingFrame{} + // don't retransmit the PING frame when it is lost + payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping, OnLost: func(wire.Frame) {}}) + payload.length += ping.Length(p.version) + p.numNonAckElicitingAcks = 0 + } else { + p.numNonAckElicitingAcks++ + } + } else { + p.numNonAckElicitingAcks = 0 + } + return payload +} + +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { + payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} + + var hasDatagram bool + if p.datagramQueue != nil { + if datagram := p.datagramQueue.Get(); datagram != nil { + payload.frames = append(payload.frames, ackhandler.Frame{ + Frame: datagram, + // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. + OnLost: func(wire.Frame) {}, + }) + payload.length += datagram.Length(p.version) + hasDatagram = true + } + } + + var ack *wire.AckFrame + hasData := p.framer.HasData() + hasRetransmission := p.retransmissionQueue.HasAppData() + // TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued + if !hasDatagram && ackAllowed { + ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) + if ack != nil { + payload.ack = ack + payload.length += ack.Length(p.version) + } + } + + if ack == nil && !hasData && !hasRetransmission { + return payload + } + + if hasRetransmission { + for { + remainingLen := maxFrameSize - payload.length + if remainingLen < protocol.MinStreamFrameSize { + break + } + f := p.retransmissionQueue.GetAppDataFrame(remainingLen) + if f == nil { + break + } + payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) + payload.length += f.Length(p.version) + } + } + + if hasData { + var lengthAdded protocol.ByteCount + payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded + + payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) + payload.length += lengthAdded + } + return payload +} + +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) { + var hdr *wire.ExtendedHeader + var payload *payload + var sealer sealer + //nolint:exhaustive // Probe packets are never sent for 0-RTT. + switch encLevel { + case protocol.EncryptionInitial: + var err error + sealer, err = p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, err + } + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial) + case protocol.EncryptionHandshake: + var err error + sealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, err + } + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake) + case protocol.Encryption1RTT: + oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + sealer = oneRTTSealer + hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) + payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) + default: + panic("unknown encryption level") + } + if payload == nil { + return nil, nil + } + size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) + var padding protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + padding = p.initialPaddingLen(payload.frames, size) + } + buffer := getPacketBuffer() + cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false) + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: cont, + }, nil +} + +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { + payload := &payload{ + frames: []ackhandler.Frame{ping}, + length: ping.Length(p.version), + } + buffer := getPacketBuffer() + sealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + hdr := p.getShortHeader(sealer.KeyPhase()) + padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) + contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true) + if err != nil { + return nil, err + } + contents.isMTUProbePacket = true + return &packedPacket{ + buffer: buffer, + packetContents: contents, + }, nil +} + +func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { + switch encLevel { + case protocol.EncryptionInitial: + sealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.EncryptionInitial) + return sealer, hdr, nil + case protocol.Encryption0RTT: + sealer, err := p.cryptoSetup.Get0RTTSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.Encryption0RTT) + return sealer, hdr, nil + case protocol.EncryptionHandshake: + sealer, err := p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getLongHeader(protocol.EncryptionHandshake) + return sealer, hdr, nil + case protocol.Encryption1RTT: + sealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, nil, err + } + hdr := p.getShortHeader(sealer.KeyPhase()) + return sealer, hdr, nil + default: + return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel) + } +} + +func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdr := &wire.ExtendedHeader{} + hdr.PacketNumber = pn + hdr.PacketNumberLen = pnLen + hdr.DestConnectionID = p.getDestConnID() + hdr.KeyPhase = kp + return hdr +} + +func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { + pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) + hdr := &wire.ExtendedHeader{ + PacketNumber: pn, + PacketNumberLen: pnLen, + } + hdr.IsLongHeader = true + hdr.Version = p.version + hdr.SrcConnectionID = p.srcConnID + hdr.DestConnectionID = p.getDestConnID() + + //nolint:exhaustive // 1-RTT packets are not long header packets. + switch encLevel { + case protocol.EncryptionInitial: + hdr.Type = protocol.PacketTypeInitial + hdr.Token = p.token + case protocol.EncryptionHandshake: + hdr.Type = protocol.PacketTypeHandshake + case protocol.Encryption0RTT: + hdr.Type = protocol.PacketType0RTT + } + return hdr +} + +// writeSinglePacket packs a single packet. +func (p *packetPacker) writeSinglePacket( + hdr *wire.ExtendedHeader, + payload *payload, + encLevel protocol.EncryptionLevel, + sealer sealer, +) (*packedPacket, error) { + buffer := getPacketBuffer() + var paddingLen protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead())) + } + contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false) + if err != nil { + return nil, err + } + return &packedPacket{ + buffer: buffer, + packetContents: contents, + }, nil +} + +func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(header.PacketNumberLen) + if payload.length < 4-pnLen { + paddingLen = 4 - pnLen - payload.length + } + paddingLen += padding + if header.IsLongHeader { + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen + } + + hdrOffset := buffer.Len() + buf := bytes.NewBuffer(buffer.Data) + if err := header.Write(buf, p.version); err != nil { + return nil, err + } + payloadOffset := buf.Len() + + if payload.ack != nil { + if err := payload.ack.Write(buf, p.version); err != nil { + return nil, err + } + } + if paddingLen > 0 { + buf.Write(make([]byte, paddingLen)) + } + for _, frame := range payload.frames { + if err := frame.Write(buf, p.version); err != nil { + return nil, err + } + } + + if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { + return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) + } + if !isMTUProbePacket { + if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { + return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + } + } + + raw := buffer.Data + // encrypt the packet + raw = raw[:buf.Len()] + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) + raw = raw[0 : buf.Len()+sealer.Overhead()] + // apply header protection + pnOffset := payloadOffset - int(header.PacketNumberLen) + sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset]) + buffer.Data = raw + + num := p.pnManager.PopPacketNumber(encLevel) + if num != header.PacketNumber { + return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + } + return &packetContents{ + header: header, + ack: payload.ack, + frames: payload.frames, + length: buffer.Len() - hdrOffset, + }, nil +} + +func (p *packetPacker) SetToken(token []byte) { + p.token = token +} + +// When a higher MTU is discovered, use it. +func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) { + p.maxPacketSize = s +} + +// If the peer sets a max_packet_size that's smaller than the size we're currently using, +// we need to reduce the size of packets we send. +func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) { + if params.MaxUDPPayloadSize != 0 { + p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxUDPPayloadSize) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go new file mode 100644 index 00000000..f70d8d07 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go @@ -0,0 +1,196 @@ +package quic + +import ( + "bytes" + "fmt" + "time" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type headerDecryptor interface { + DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) +} + +type headerParseError struct { + err error +} + +func (e *headerParseError) Unwrap() error { + return e.err +} + +func (e *headerParseError) Error() string { + return e.err.Error() +} + +type unpackedPacket struct { + packetNumber protocol.PacketNumber // the decoded packet number + hdr *wire.ExtendedHeader + encryptionLevel protocol.EncryptionLevel + data []byte +} + +// The packetUnpacker unpacks QUIC packets. +type packetUnpacker struct { + cs handshake.CryptoSetup + + version protocol.VersionNumber +} + +var _ unpacker = &packetUnpacker{} + +func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker { + return &packetUnpacker{ + cs: cs, + version: version, + } +} + +// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. +// If any other error occurred when parsing the header, the error is of type headerParseError. +// If decrypting the payload fails for any reason, the error is the error returned by the AEAD. +func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + var encLevel protocol.EncryptionLevel + var extHdr *wire.ExtendedHeader + var decrypted []byte + //nolint:exhaustive // Retry packets can't be unpacked. + switch hdr.Type { + case protocol.PacketTypeInitial: + encLevel = protocol.EncryptionInitial + opener, err := u.cs.GetInitialOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } + case protocol.PacketTypeHandshake: + encLevel = protocol.EncryptionHandshake + opener, err := u.cs.GetHandshakeOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } + case protocol.PacketType0RTT: + encLevel = protocol.Encryption0RTT + opener, err := u.cs.Get0RTTOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + if err != nil { + return nil, err + } + default: + if hdr.IsLongHeader { + return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) + } + encLevel = protocol.Encryption1RTT + opener, err := u.cs.Get1RTTOpener() + if err != nil { + return nil, err + } + extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data) + if err != nil { + return nil, err + } + } + + return &unpackedPacket{ + hdr: extHdr, + packetNumber: extHdr.PacketNumber, + encryptionLevel: encLevel, + data: decrypted, + }, nil +} + +func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { + extHdr, parseErr := u.unpackHeader(opener, hdr, data) + // If the reserved bits are set incorrectly, we still need to continue unpacking. + // This avoids a timing side-channel, which otherwise might allow an attacker + // to gain information about the header encryption. + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return nil, nil, parseErr + } + extHdrLen := extHdr.ParsedLen() + extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) + decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) + if err != nil { + return nil, nil, err + } + if parseErr != nil { + return nil, nil, parseErr + } + return extHdr, decrypted, nil +} + +func (u *packetUnpacker) unpackShortHeaderPacket( + opener handshake.ShortHeaderOpener, + hdr *wire.Header, + rcvTime time.Time, + data []byte, +) (*wire.ExtendedHeader, []byte, error) { + extHdr, parseErr := u.unpackHeader(opener, hdr, data) + // If the reserved bits are set incorrectly, we still need to continue unpacking. + // This avoids a timing side-channel, which otherwise might allow an attacker + // to gain information about the header encryption. + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return nil, nil, parseErr + } + extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) + extHdrLen := extHdr.ParsedLen() + decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) + if err != nil { + return nil, nil, err + } + if parseErr != nil { + return nil, nil, parseErr + } + return extHdr, decrypted, nil +} + +// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. +func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { + extHdr, err := unpackHeader(hd, hdr, data, u.version) + if err != nil && err != wire.ErrInvalidReservedBits { + return nil, &headerParseError{err: err} + } + return extHdr, err +} + +func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { + r := bytes.NewReader(data) + + hdrLen := hdr.ParsedLen() + if protocol.ByteCount(len(data)) < hdrLen+4+16 { + //nolint:stylecheck + return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen) + } + // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it. + // 1. save a copy of the 4 bytes + origPNBytes := make([]byte, 4) + copy(origPNBytes, data[hdrLen:hdrLen+4]) + // 2. decrypt the header, assuming a 4 byte packet number + hd.DecryptHeader( + data[hdrLen+4:hdrLen+4+16], + &data[0], + data[hdrLen:hdrLen+4], + ) + // 3. parse the header (and learn the actual length of the packet number) + extHdr, parseErr := hdr.ParseExtended(r, version) + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return nil, parseErr + } + // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier + if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { + copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) + } + return extHdr, parseErr +} diff --git a/vendor/github.com/lucas-clemente/quic-go/quicvarint/io.go b/vendor/github.com/lucas-clemente/quic-go/quicvarint/io.go new file mode 100644 index 00000000..9368d1c5 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/quicvarint/io.go @@ -0,0 +1,68 @@ +package quicvarint + +import ( + "bytes" + "io" +) + +// Reader implements both the io.ByteReader and io.Reader interfaces. +type Reader interface { + io.ByteReader + io.Reader +} + +var _ Reader = &bytes.Reader{} + +type byteReader struct { + io.Reader +} + +var _ Reader = &byteReader{} + +// NewReader returns a Reader for r. +// If r already implements both io.ByteReader and io.Reader, NewReader returns r. +// Otherwise, r is wrapped to add the missing interfaces. +func NewReader(r io.Reader) Reader { + if r, ok := r.(Reader); ok { + return r + } + return &byteReader{r} +} + +func (r *byteReader) ReadByte() (byte, error) { + var b [1]byte + n, err := r.Reader.Read(b[:]) + if n == 1 && err == io.EOF { + err = nil + } + return b[0], err +} + +// Writer implements both the io.ByteWriter and io.Writer interfaces. +type Writer interface { + io.ByteWriter + io.Writer +} + +var _ Writer = &bytes.Buffer{} + +type byteWriter struct { + io.Writer +} + +var _ Writer = &byteWriter{} + +// NewWriter returns a Writer for w. +// If r already implements both io.ByteWriter and io.Writer, NewWriter returns w. +// Otherwise, w is wrapped to add the missing interfaces. +func NewWriter(w io.Writer) Writer { + if w, ok := w.(Writer); ok { + return w + } + return &byteWriter{w} +} + +func (w *byteWriter) WriteByte(c byte) error { + _, err := w.Writer.Write([]byte{c}) + return err +} diff --git a/vendor/github.com/lucas-clemente/quic-go/quicvarint/varint.go b/vendor/github.com/lucas-clemente/quic-go/quicvarint/varint.go new file mode 100644 index 00000000..66e0d39e --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/quicvarint/varint.go @@ -0,0 +1,139 @@ +package quicvarint + +import ( + "fmt" + "io" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// taken from the QUIC draft +const ( + // Min is the minimum value allowed for a QUIC varint. + Min = 0 + + // Max is the maximum allowed value for a QUIC varint (2^62-1). + Max = maxVarInt8 + + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// Read reads a number in the QUIC varint format from r. +func Read(r io.ByteReader) (uint64, error) { + firstByte, err := r.ReadByte() + if err != nil { + return 0, err + } + // the first two bits of the first byte encode the length + len := 1 << ((firstByte & 0xc0) >> 6) + b1 := firstByte & (0xff - 0xc0) + if len == 1 { + return uint64(b1), nil + } + b2, err := r.ReadByte() + if err != nil { + return 0, err + } + if len == 2 { + return uint64(b2) + uint64(b1)<<8, nil + } + b3, err := r.ReadByte() + if err != nil { + return 0, err + } + b4, err := r.ReadByte() + if err != nil { + return 0, err + } + if len == 4 { + return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil + } + b5, err := r.ReadByte() + if err != nil { + return 0, err + } + b6, err := r.ReadByte() + if err != nil { + return 0, err + } + b7, err := r.ReadByte() + if err != nil { + return 0, err + } + b8, err := r.ReadByte() + if err != nil { + return 0, err + } + return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil +} + +// Write writes i in the QUIC varint format to w. +func Write(w Writer, i uint64) { + if i <= maxVarInt1 { + w.WriteByte(uint8(i)) + } else if i <= maxVarInt2 { + w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) + } else if i <= maxVarInt4 { + w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) + } else if i <= maxVarInt8 { + w.Write([]byte{ + uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }) + } else { + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) + } +} + +// WriteWithLen writes i in the QUIC varint format with the desired length to w. +func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) { + if length != 1 && length != 2 && length != 4 && length != 8 { + panic("invalid varint length") + } + l := Len(i) + if l == length { + Write(w, i) + return + } + if l > length { + panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) + } + if length == 2 { + w.WriteByte(0b01000000) + } else if length == 4 { + w.WriteByte(0b10000000) + } else if length == 8 { + w.WriteByte(0b11000000) + } + for j := protocol.ByteCount(1); j < length-l; j++ { + w.WriteByte(0) + } + for j := protocol.ByteCount(0); j < l; j++ { + w.WriteByte(uint8(i >> (8 * (l - 1 - j)))) + } +} + +// Len determines the number of bytes that will be needed to write the number i. +func Len(i uint64) protocol.ByteCount { + if i <= maxVarInt1 { + return 1 + } + if i <= maxVarInt2 { + return 2 + } + if i <= maxVarInt4 { + return 4 + } + if i <= maxVarInt8 { + return 8 + } + // Don't use a fmt.Sprintf here to format the error message. + // The function would then exceed the inlining budget. + panic(struct { + message string + num uint64 + }{"value doesn't fit into 62 bits: ", i}) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go new file mode 100644 index 00000000..ae6a449b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/receive_stream.go @@ -0,0 +1,331 @@ +package quic + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type receiveStreamI interface { + ReceiveStream + + handleStreamFrame(*wire.StreamFrame) error + handleResetStreamFrame(*wire.ResetStreamFrame) error + closeForShutdown(error) + getWindowUpdate() protocol.ByteCount +} + +type receiveStream struct { + mutex sync.Mutex + + streamID protocol.StreamID + + sender streamSender + + frameQueue *frameSorter + finalOffset protocol.ByteCount + + currentFrame []byte + currentFrameDone func() + currentFrameIsLast bool // is the currentFrame the last frame on this stream + readPosInFrame int + + closeForShutdownErr error + cancelReadErr error + resetRemotelyErr *StreamError + + closedForShutdown bool // set when CloseForShutdown() is called + finRead bool // set once we read a frame with a Fin + canceledRead bool // set when CancelRead() is called + resetRemotely bool // set when HandleResetStreamFrame() is called + + readChan chan struct{} + readOnce chan struct{} // cap: 1, to protect against concurrent use of Read + deadline time.Time + + flowController flowcontrol.StreamFlowController + version protocol.VersionNumber +} + +var ( + _ ReceiveStream = &receiveStream{} + _ receiveStreamI = &receiveStream{} +) + +func newReceiveStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *receiveStream { + return &receiveStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + frameQueue: newFrameSorter(), + readChan: make(chan struct{}, 1), + readOnce: make(chan struct{}, 1), + finalOffset: protocol.MaxByteCount, + version: version, + } +} + +func (s *receiveStream) StreamID() protocol.StreamID { + return s.streamID +} + +// Read implements io.Reader. It is not thread safe! +func (s *receiveStream) Read(p []byte) (int, error) { + // Concurrent use of Read is not permitted (and doesn't make any sense), + // but sometimes people do it anyway. + // Make sure that we only execute one call at any given time to avoid hard to debug failures. + s.readOnce <- struct{}{} + defer func() { <-s.readOnce }() + + s.mutex.Lock() + completed, n, err := s.readImpl(p) + s.mutex.Unlock() + + if completed { + s.sender.onStreamCompleted(s.streamID) + } + return n, err +} + +func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) { + if s.finRead { + return false, 0, io.EOF + } + if s.canceledRead { + return false, 0, s.cancelReadErr + } + if s.resetRemotely { + return false, 0, s.resetRemotelyErr + } + if s.closedForShutdown { + return false, 0, s.closeForShutdownErr + } + + var bytesRead int + var deadlineTimer *utils.Timer + for bytesRead < len(p) { + if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) { + s.dequeueNextFrame() + } + if s.currentFrame == nil && bytesRead > 0 { + return false, bytesRead, s.closeForShutdownErr + } + + for { + // Stop waiting on errors + if s.closedForShutdown { + return false, bytesRead, s.closeForShutdownErr + } + if s.canceledRead { + return false, bytesRead, s.cancelReadErr + } + if s.resetRemotely { + return false, bytesRead, s.resetRemotelyErr + } + + deadline := s.deadline + if !deadline.IsZero() { + if !time.Now().Before(deadline) { + return false, bytesRead, errDeadline + } + if deadlineTimer == nil { + deadlineTimer = utils.NewTimer() + defer deadlineTimer.Stop() + } + deadlineTimer.Reset(deadline) + } + + if s.currentFrame != nil || s.currentFrameIsLast { + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.readChan + } else { + select { + case <-s.readChan: + case <-deadlineTimer.Chan(): + deadlineTimer.SetRead() + } + } + s.mutex.Lock() + if s.currentFrame == nil { + s.dequeueNextFrame() + } + } + + if bytesRead > len(p) { + return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p)) + } + if s.readPosInFrame > len(s.currentFrame) { + return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame)) + } + + m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:]) + s.readPosInFrame += m + bytesRead += m + + // when a RESET_STREAM was received, the was already informed about the final byteOffset for this stream + if !s.resetRemotely { + s.flowController.AddBytesRead(protocol.ByteCount(m)) + } + + if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { + s.finRead = true + return true, bytesRead, io.EOF + } + } + return false, bytesRead, nil +} + +func (s *receiveStream) dequeueNextFrame() { + var offset protocol.ByteCount + // We're done with the last frame. Release the buffer. + if s.currentFrameDone != nil { + s.currentFrameDone() + } + offset, s.currentFrame, s.currentFrameDone = s.frameQueue.Pop() + s.currentFrameIsLast = offset+protocol.ByteCount(len(s.currentFrame)) >= s.finalOffset + s.readPosInFrame = 0 +} + +func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { + s.mutex.Lock() + completed := s.cancelReadImpl(errorCode) + s.mutex.Unlock() + + if completed { + s.flowController.Abandon() + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ { + if s.finRead || s.canceledRead || s.resetRemotely { + return false + } + s.canceledRead = true + s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) + s.signalRead() + s.sender.queueControlFrame(&wire.StopSendingFrame{ + StreamID: s.streamID, + ErrorCode: errorCode, + }) + // We're done with this stream if the final offset was already received. + return s.finalOffset != protocol.MaxByteCount +} + +func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error { + s.mutex.Lock() + completed, err := s.handleStreamFrameImpl(frame) + s.mutex.Unlock() + + if completed { + s.flowController.Abandon() + s.sender.onStreamCompleted(s.streamID) + } + return err +} + +func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) { + maxOffset := frame.Offset + frame.DataLen() + if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil { + return false, err + } + var newlyRcvdFinalOffset bool + if frame.Fin { + newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount + s.finalOffset = maxOffset + } + if s.canceledRead { + return newlyRcvdFinalOffset, nil + } + if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil { + return false, err + } + s.signalRead() + return false, nil +} + +func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { + s.mutex.Lock() + completed, err := s.handleResetStreamFrameImpl(frame) + s.mutex.Unlock() + + if completed { + s.flowController.Abandon() + s.sender.onStreamCompleted(s.streamID) + } + return err +} + +func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) { + if s.closedForShutdown { + return false, nil + } + if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil { + return false, err + } + newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount + s.finalOffset = frame.FinalSize + + // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset) + if s.resetRemotely { + return false, nil + } + s.resetRemotely = true + s.resetRemotelyErr = &StreamError{ + StreamID: s.streamID, + ErrorCode: frame.ErrorCode, + } + s.signalRead() + return newlyRcvdFinalOffset, nil +} + +func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { + s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset}) +} + +func (s *receiveStream) SetReadDeadline(t time.Time) error { + s.mutex.Lock() + s.deadline = t + s.mutex.Unlock() + s.signalRead() + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Read unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET. +func (s *receiveStream) closeForShutdown(err error) { + s.mutex.Lock() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalRead() +} + +func (s *receiveStream) getWindowUpdate() protocol.ByteCount { + return s.flowController.GetWindowUpdate() +} + +// signalRead performs a non-blocking send on the readChan +func (s *receiveStream) signalRead() { + select { + case s.readChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/retransmission_queue.go b/vendor/github.com/lucas-clemente/quic-go/retransmission_queue.go new file mode 100644 index 00000000..0cfbbc4d --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/retransmission_queue.go @@ -0,0 +1,131 @@ +package quic + +import ( + "fmt" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type retransmissionQueue struct { + initial []wire.Frame + initialCryptoData []*wire.CryptoFrame + + handshake []wire.Frame + handshakeCryptoData []*wire.CryptoFrame + + appData []wire.Frame + + version protocol.VersionNumber +} + +func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue { + return &retransmissionQueue{version: ver} +} + +func (q *retransmissionQueue) AddInitial(f wire.Frame) { + if cf, ok := f.(*wire.CryptoFrame); ok { + q.initialCryptoData = append(q.initialCryptoData, cf) + return + } + q.initial = append(q.initial, f) +} + +func (q *retransmissionQueue) AddHandshake(f wire.Frame) { + if cf, ok := f.(*wire.CryptoFrame); ok { + q.handshakeCryptoData = append(q.handshakeCryptoData, cf) + return + } + q.handshake = append(q.handshake, f) +} + +func (q *retransmissionQueue) HasInitialData() bool { + return len(q.initialCryptoData) > 0 || len(q.initial) > 0 +} + +func (q *retransmissionQueue) HasHandshakeData() bool { + return len(q.handshakeCryptoData) > 0 || len(q.handshake) > 0 +} + +func (q *retransmissionQueue) HasAppData() bool { + return len(q.appData) > 0 +} + +func (q *retransmissionQueue) AddAppData(f wire.Frame) { + if _, ok := f.(*wire.StreamFrame); ok { + panic("STREAM frames are handled with their respective streams.") + } + q.appData = append(q.appData, f) +} + +func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.initialCryptoData) > 0 { + f := q.initialCryptoData[0] + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + if newFrame == nil && !needsSplit { // the whole frame fits + q.initialCryptoData = q.initialCryptoData[1:] + return f + } + if newFrame != nil { // frame was split. Leave the original frame in the queue. + return newFrame + } + } + if len(q.initial) == 0 { + return nil + } + f := q.initial[0] + if f.Length(q.version) > maxLen { + return nil + } + q.initial = q.initial[1:] + return f +} + +func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.handshakeCryptoData) > 0 { + f := q.handshakeCryptoData[0] + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + if newFrame == nil && !needsSplit { // the whole frame fits + q.handshakeCryptoData = q.handshakeCryptoData[1:] + return f + } + if newFrame != nil { // frame was split. Leave the original frame in the queue. + return newFrame + } + } + if len(q.handshake) == 0 { + return nil + } + f := q.handshake[0] + if f.Length(q.version) > maxLen { + return nil + } + q.handshake = q.handshake[1:] + return f +} + +func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame { + if len(q.appData) == 0 { + return nil + } + f := q.appData[0] + if f.Length(q.version) > maxLen { + return nil + } + q.appData = q.appData[1:] + return f +} + +func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { + //nolint:exhaustive // Can only drop Initial and Handshake packet number space. + switch encLevel { + case protocol.EncryptionInitial: + q.initial = nil + q.initialCryptoData = nil + case protocol.EncryptionHandshake: + q.handshake = nil + q.handshakeCryptoData = nil + default: + panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/send_conn.go b/vendor/github.com/lucas-clemente/quic-go/send_conn.go new file mode 100644 index 00000000..c53ebdfa --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/send_conn.go @@ -0,0 +1,74 @@ +package quic + +import ( + "net" +) + +// A sendConn allows sending using a simple Write() on a non-connected packet conn. +type sendConn interface { + Write([]byte) error + Close() error + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +type sconn struct { + rawConn + + remoteAddr net.Addr + info *packetInfo + oob []byte +} + +var _ sendConn = &sconn{} + +func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { + return &sconn{ + rawConn: c, + remoteAddr: remote, + info: info, + oob: info.OOB(), + } +} + +func (c *sconn) Write(p []byte) error { + _, err := c.WritePacket(p, c.remoteAddr, c.oob) + return err +} + +func (c *sconn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *sconn) LocalAddr() net.Addr { + addr := c.rawConn.LocalAddr() + if c.info != nil { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + addrCopy := *udpAddr + addrCopy.IP = c.info.addr + addr = &addrCopy + } + } + return addr +} + +type spconn struct { + net.PacketConn + + remoteAddr net.Addr +} + +var _ sendConn = &spconn{} + +func newSendPconn(c net.PacketConn, remote net.Addr) sendConn { + return &spconn{PacketConn: c, remoteAddr: remote} +} + +func (c *spconn) Write(p []byte) error { + _, err := c.WriteTo(p, c.remoteAddr) + return err +} + +func (c *spconn) RemoteAddr() net.Addr { + return c.remoteAddr +} diff --git a/vendor/github.com/lucas-clemente/quic-go/send_queue.go b/vendor/github.com/lucas-clemente/quic-go/send_queue.go new file mode 100644 index 00000000..1fc8c1bf --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/send_queue.go @@ -0,0 +1,88 @@ +package quic + +type sender interface { + Send(p *packetBuffer) + Run() error + WouldBlock() bool + Available() <-chan struct{} + Close() +} + +type sendQueue struct { + queue chan *packetBuffer + closeCalled chan struct{} // runStopped when Close() is called + runStopped chan struct{} // runStopped when the run loop returns + available chan struct{} + conn sendConn +} + +var _ sender = &sendQueue{} + +const sendQueueCapacity = 8 + +func newSendQueue(conn sendConn) sender { + return &sendQueue{ + conn: conn, + runStopped: make(chan struct{}), + closeCalled: make(chan struct{}), + available: make(chan struct{}, 1), + queue: make(chan *packetBuffer, sendQueueCapacity), + } +} + +// Send sends out a packet. It's guaranteed to not block. +// Callers need to make sure that there's actually space in the send queue by calling WouldBlock. +// Otherwise Send will panic. +func (h *sendQueue) Send(p *packetBuffer) { + select { + case h.queue <- p: + case <-h.runStopped: + default: + panic("sendQueue.Send would have blocked") + } +} + +func (h *sendQueue) WouldBlock() bool { + return len(h.queue) == sendQueueCapacity +} + +func (h *sendQueue) Available() <-chan struct{} { + return h.available +} + +func (h *sendQueue) Run() error { + defer close(h.runStopped) + var shouldClose bool + for { + if shouldClose && len(h.queue) == 0 { + return nil + } + select { + case <-h.closeCalled: + h.closeCalled = nil // prevent this case from being selected again + // make sure that all queued packets are actually sent out + shouldClose = true + case p := <-h.queue: + if err := h.conn.Write(p.Data); err != nil { + // This additional check enables: + // 1. Checking for "datagram too large" message from the kernel, as such, + // 2. Path MTU discovery,and + // 3. Eventual detection of loss PingFrame. + if !isMsgSizeErr(err) { + return err + } + } + p.Release() + select { + case h.available <- struct{}{}: + default: + } + } + } +} + +func (h *sendQueue) Close() { + close(h.closeCalled) + // wait until the run loop returned + <-h.runStopped +} diff --git a/vendor/github.com/lucas-clemente/quic-go/send_stream.go b/vendor/github.com/lucas-clemente/quic-go/send_stream.go new file mode 100644 index 00000000..b23df00b --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/send_stream.go @@ -0,0 +1,496 @@ +package quic + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type sendStreamI interface { + SendStream + handleStopSendingFrame(*wire.StopSendingFrame) + hasData() bool + popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) + closeForShutdown(error) + updateSendWindow(protocol.ByteCount) +} + +type sendStream struct { + mutex sync.Mutex + + numOutstandingFrames int64 + retransmissionQueue []*wire.StreamFrame + + ctx context.Context + ctxCancel context.CancelFunc + + streamID protocol.StreamID + sender streamSender + + writeOffset protocol.ByteCount + + cancelWriteErr error + closeForShutdownErr error + + closedForShutdown bool // set when CloseForShutdown() is called + finishedWriting bool // set once Close() is called + canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received + finSent bool // set when a STREAM_FRAME with FIN bit has been sent + completed bool // set when this stream has been reported to the streamSender as completed + + dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out + nextFrame *wire.StreamFrame + + writeChan chan struct{} + writeOnce chan struct{} + deadline time.Time + + flowController flowcontrol.StreamFlowController + + version protocol.VersionNumber +} + +var ( + _ SendStream = &sendStream{} + _ sendStreamI = &sendStream{} +) + +func newSendStream( + streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *sendStream { + s := &sendStream{ + streamID: streamID, + sender: sender, + flowController: flowController, + writeChan: make(chan struct{}, 1), + writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write + version: version, + } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + return s +} + +func (s *sendStream) StreamID() protocol.StreamID { + return s.streamID // same for receiveStream and sendStream +} + +func (s *sendStream) Write(p []byte) (int, error) { + // Concurrent use of Write is not permitted (and doesn't make any sense), + // but sometimes people do it anyway. + // Make sure that we only execute one call at any given time to avoid hard to debug failures. + s.writeOnce <- struct{}{} + defer func() { <-s.writeOnce }() + + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.finishedWriting { + return 0, fmt.Errorf("write on closed stream %d", s.streamID) + } + if s.canceledWrite { + return 0, s.cancelWriteErr + } + if s.closeForShutdownErr != nil { + return 0, s.closeForShutdownErr + } + if !s.deadline.IsZero() && !time.Now().Before(s.deadline) { + return 0, errDeadline + } + if len(p) == 0 { + return 0, nil + } + + s.dataForWriting = p + + var ( + deadlineTimer *utils.Timer + bytesWritten int + notifiedSender bool + ) + for { + var copied bool + var deadline time.Time + // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame), + // which can the be popped the next time we assemble a packet. + // This allows us to return Write() when all data but x bytes have been sent out. + // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame, + // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN). + if s.canBufferStreamFrame() && len(s.dataForWriting) > 0 { + if s.nextFrame == nil { + f := wire.GetStreamFrame() + f.Offset = s.writeOffset + f.StreamID = s.streamID + f.DataLenPresent = true + f.Data = f.Data[:len(s.dataForWriting)] + copy(f.Data, s.dataForWriting) + s.nextFrame = f + } else { + l := len(s.nextFrame.Data) + s.nextFrame.Data = s.nextFrame.Data[:l+len(s.dataForWriting)] + copy(s.nextFrame.Data[l:], s.dataForWriting) + } + s.dataForWriting = nil + bytesWritten = len(p) + copied = true + } else { + bytesWritten = len(p) - len(s.dataForWriting) + deadline = s.deadline + if !deadline.IsZero() { + if !time.Now().Before(deadline) { + s.dataForWriting = nil + return bytesWritten, errDeadline + } + if deadlineTimer == nil { + deadlineTimer = utils.NewTimer() + defer deadlineTimer.Stop() + } + deadlineTimer.Reset(deadline) + } + if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { + break + } + } + + s.mutex.Unlock() + if !notifiedSender { + s.sender.onHasStreamData(s.streamID) // must be called without holding the mutex + notifiedSender = true + } + if copied { + s.mutex.Lock() + break + } + if deadline.IsZero() { + <-s.writeChan + } else { + select { + case <-s.writeChan: + case <-deadlineTimer.Chan(): + deadlineTimer.SetRead() + } + } + s.mutex.Lock() + } + + if bytesWritten == len(p) { + return bytesWritten, nil + } + if s.closeForShutdownErr != nil { + return bytesWritten, s.closeForShutdownErr + } else if s.cancelWriteErr != nil { + return bytesWritten, s.cancelWriteErr + } + return bytesWritten, nil +} + +func (s *sendStream) canBufferStreamFrame() bool { + var l protocol.ByteCount + if s.nextFrame != nil { + l = s.nextFrame.DataLen() + } + return l+protocol.ByteCount(len(s.dataForWriting)) <= protocol.MaxPacketBufferSize +} + +// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream +// maxBytes is the maximum length this frame (including frame header) will have. +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) { + s.mutex.Lock() + f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes) + if f != nil { + s.numOutstandingFrames++ + } + s.mutex.Unlock() + + if f == nil { + return nil, hasMoreData + } + return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData +} + +func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { + if s.canceledWrite || s.closeForShutdownErr != nil { + return nil, false + } + + if len(s.retransmissionQueue) > 0 { + f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes) + if f != nil || hasMoreRetransmissions { + if f == nil { + return nil, true + } + // We always claim that we have more data to send. + // This might be incorrect, in which case there'll be a spurious call to popStreamFrame in the future. + return f, true + } + } + + if len(s.dataForWriting) == 0 && s.nextFrame == nil { + if s.finishedWriting && !s.finSent { + s.finSent = true + return &wire.StreamFrame{ + StreamID: s.streamID, + Offset: s.writeOffset, + DataLenPresent: true, + Fin: true, + }, false + } + return nil, false + } + + sendWindow := s.flowController.SendWindowSize() + if sendWindow == 0 { + if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked { + s.sender.queueControlFrame(&wire.StreamDataBlockedFrame{ + StreamID: s.streamID, + MaximumStreamData: offset, + }) + return nil, false + } + return nil, true + } + + f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow) + if dataLen := f.DataLen(); dataLen > 0 { + s.writeOffset += f.DataLen() + s.flowController.AddBytesSent(f.DataLen()) + } + f.Fin = s.finishedWriting && s.dataForWriting == nil && s.nextFrame == nil && !s.finSent + if f.Fin { + s.finSent = true + } + return f, hasMoreData +} + +func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) { + if s.nextFrame != nil { + nextFrame := s.nextFrame + s.nextFrame = nil + + maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version)) + if nextFrame.DataLen() > maxDataLen { + s.nextFrame = wire.GetStreamFrame() + s.nextFrame.StreamID = s.streamID + s.nextFrame.Offset = s.writeOffset + maxDataLen + s.nextFrame.Data = s.nextFrame.Data[:nextFrame.DataLen()-maxDataLen] + s.nextFrame.DataLenPresent = true + copy(s.nextFrame.Data, nextFrame.Data[maxDataLen:]) + nextFrame.Data = nextFrame.Data[:maxDataLen] + } else { + s.signalWrite() + } + return nextFrame, s.nextFrame != nil || s.dataForWriting != nil + } + + f := wire.GetStreamFrame() + f.Fin = false + f.StreamID = s.streamID + f.Offset = s.writeOffset + f.DataLenPresent = true + f.Data = f.Data[:0] + + hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow) + if len(f.Data) == 0 && !f.Fin { + f.PutBack() + return nil, hasMoreData + } + return f, hasMoreData +} + +func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool { + maxDataLen := f.MaxDataLen(maxBytes, s.version) + if maxDataLen == 0 { // a STREAM frame must have at least one byte of data + return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting + } + s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow)) + + return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting +} + +func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) { + f := s.retransmissionQueue[0] + newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version) + if needsSplit { + return newFrame, true + } + s.retransmissionQueue = s.retransmissionQueue[1:] + return f, len(s.retransmissionQueue) > 0 +} + +func (s *sendStream) hasData() bool { + s.mutex.Lock() + hasData := len(s.dataForWriting) > 0 + s.mutex.Unlock() + return hasData +} + +func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.ByteCount) { + if protocol.ByteCount(len(s.dataForWriting)) <= maxBytes { + f.Data = f.Data[:len(s.dataForWriting)] + copy(f.Data, s.dataForWriting) + s.dataForWriting = nil + s.signalWrite() + return + } + f.Data = f.Data[:maxBytes] + copy(f.Data, s.dataForWriting) + s.dataForWriting = s.dataForWriting[maxBytes:] + if s.canBufferStreamFrame() { + s.signalWrite() + } +} + +func (s *sendStream) frameAcked(f wire.Frame) { + f.(*wire.StreamFrame).PutBack() + + s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + newlyCompleted := s.isNewlyCompleted() + s.mutex.Unlock() + + if newlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *sendStream) isNewlyCompleted() bool { + completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 + if completed && !s.completed { + s.completed = true + return true + } + return false +} + +func (s *sendStream) queueRetransmission(f wire.Frame) { + sf := f.(*wire.StreamFrame) + sf.DataLenPresent = true + s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } + s.retransmissionQueue = append(s.retransmissionQueue, sf) + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + s.mutex.Unlock() + + s.sender.onHasStreamData(s.streamID) +} + +func (s *sendStream) Close() error { + s.mutex.Lock() + if s.closedForShutdown { + s.mutex.Unlock() + return nil + } + if s.canceledWrite { + s.mutex.Unlock() + return fmt.Errorf("close called for canceled stream %d", s.streamID) + } + s.ctxCancel() + s.finishedWriting = true + s.mutex.Unlock() + + s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex + return nil +} + +func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { + s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) +} + +// must be called after locking the mutex +func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) { + s.mutex.Lock() + if s.canceledWrite { + s.mutex.Unlock() + return + } + s.ctxCancel() + s.canceledWrite = true + s.cancelWriteErr = writeErr + s.numOutstandingFrames = 0 + s.retransmissionQueue = nil + newlyCompleted := s.isNewlyCompleted() + s.mutex.Unlock() + + s.signalWrite() + s.sender.queueControlFrame(&wire.ResetStreamFrame{ + StreamID: s.streamID, + FinalSize: s.writeOffset, + ErrorCode: errorCode, + }) + if newlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { + s.mutex.Lock() + hasStreamData := s.dataForWriting != nil || s.nextFrame != nil + s.mutex.Unlock() + + s.flowController.UpdateSendWindow(limit) + if hasStreamData { + s.sender.onHasStreamData(s.streamID) + } +} + +func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { + s.cancelWriteImpl(frame.ErrorCode, &StreamError{ + StreamID: s.streamID, + ErrorCode: frame.ErrorCode, + }) +} + +func (s *sendStream) Context() context.Context { + return s.ctx +} + +func (s *sendStream) SetWriteDeadline(t time.Time) error { + s.mutex.Lock() + s.deadline = t + s.mutex.Unlock() + s.signalWrite() + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *sendStream) closeForShutdown(err error) { + s.mutex.Lock() + s.ctxCancel() + s.closedForShutdown = true + s.closeForShutdownErr = err + s.mutex.Unlock() + s.signalWrite() +} + +// signalWrite performs a non-blocking send on the writeChan +func (s *sendStream) signalWrite() { + select { + case s.writeChan <- struct{}{}: + default: + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go new file mode 100644 index 00000000..0e642970 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/server.go @@ -0,0 +1,670 @@ +package quic + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/logging" +) + +// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. +var ErrServerClosed = errors.New("quic: Server closed") + +// packetHandler handles packets +type packetHandler interface { + handlePacket(*receivedPacket) + shutdown() + destroy(error) + getPerspective() protocol.Perspective +} + +type unknownPacketHandler interface { + handlePacket(*receivedPacket) + setCloseError(error) +} + +type packetHandlerManager interface { + AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool + Destroy() error + connRunner + SetServer(unknownPacketHandler) + CloseServer() +} + +type quicConn interface { + EarlyConnection + earlyConnReady() <-chan struct{} + handlePacket(*receivedPacket) + GetVersion() protocol.VersionNumber + getPerspective() protocol.Perspective + run() error + destroy(error) + shutdown() +} + +// A Listener of QUIC +type baseServer struct { + mutex sync.Mutex + + acceptEarlyConns bool + + tlsConf *tls.Config + config *Config + + conn rawConn + // If the server is started with ListenAddr, we create a packet conn. + // If it is started with Listen, we take a packet conn as a parameter. + createdPacketConn bool + + tokenGenerator *handshake.TokenGenerator + + connHandler packetHandlerManager + + receivedPackets chan *receivedPacket + + // set as a member, so they can be set in the tests + newConn func( + sendConn, + connRunner, + protocol.ConnectionID, /* original dest connection ID */ + *protocol.ConnectionID, /* retry src connection ID */ + protocol.ConnectionID, /* client dest connection ID */ + protocol.ConnectionID, /* destination connection ID */ + protocol.ConnectionID, /* source connection ID */ + protocol.StatelessResetToken, + *Config, + *tls.Config, + *handshake.TokenGenerator, + bool, /* enable 0-RTT */ + logging.ConnectionTracer, + uint64, + utils.Logger, + protocol.VersionNumber, + ) quicConn + + serverError error + errorChan chan struct{} + closed bool + running chan struct{} // closed as soon as run() returns + + connQueue chan quicConn + connQueueLen int32 // to be used as an atomic + + logger utils.Logger +} + +var ( + _ Listener = &baseServer{} + _ unknownPacketHandler = &baseServer{} +) + +type earlyServer struct{ *baseServer } + +var _ EarlyListener = &earlyServer{} + +func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { + return s.baseServer.accept(ctx) +} + +// ListenAddr creates a QUIC server listening on a given address. +// The tls.Config must not be nil and must contain a certificate configuration. +// The quic.Config may be nil, in that case the default values will be used. +func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { + return listenAddr(addr, tlsConf, config, false) +} + +// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. +func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { + s, err := listenAddr(addr, tlsConf, config, true) + if err != nil { + return nil, err + } + return &earlyServer{s}, nil +} + +func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + serv, err := listen(conn, tlsConf, config, acceptEarly) + if err != nil { + return nil, err + } + serv.createdPacketConn = true + return serv, nil +} + +// Listen listens for QUIC connections on a given net.PacketConn. If the +// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn +// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP +// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write +// packets. A single net.PacketConn only be used for a single call to Listen. +// The PacketConn can be used for simultaneous calls to Dial. QUIC connection +// IDs are used for demultiplexing the different connections. The tls.Config +// must not be nil and must contain a certificate configuration. The +// tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. Furthermore, +// it must define an application control (using NextProtos). The quic.Config may +// be nil, in that case the default values will be used. +func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { + return listen(conn, tlsConf, config, false) +} + +// ListenEarly works like Listen, but it returns connections before the handshake completes. +func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { + s, err := listen(conn, tlsConf, config, true) + if err != nil { + return nil, err + } + return &earlyServer{s}, nil +} + +func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(config); err != nil { + return nil, err + } + config = populateServerConfig(config) + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return nil, fmt.Errorf("%s is not a valid QUIC version", v) + } + } + + connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + if err != nil { + return nil, err + } + tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) + if err != nil { + return nil, err + } + c, err := wrapConn(conn) + if err != nil { + return nil, err + } + s := &baseServer{ + conn: c, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), + newConn: newConnection, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + } + go s.run() + connHandler.SetServer(s) + s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) + return s, nil +} + +func (s *baseServer) run() { + defer close(s.running) + for { + select { + case <-s.errorChan: + return + default: + } + select { + case <-s.errorChan: + return + case p := <-s.receivedPackets: + if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse { + p.buffer.Release() + } + } + } +} + +var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { + if token == nil { + return false + } + validity := protocol.TokenValidity + if token.IsRetryToken { + validity = protocol.RetryTokenValidity + } + if time.Now().After(token.SentTime.Add(validity)) { + return false + } + var sourceAddr string + if udpAddr, ok := clientAddr.(*net.UDPAddr); ok { + sourceAddr = udpAddr.IP.String() + } else { + sourceAddr = clientAddr.String() + } + return sourceAddr == token.RemoteAddr +} + +// Accept returns connections that already completed the handshake. +// It is only valid if acceptEarlyConns is false. +func (s *baseServer) Accept(ctx context.Context) (Connection, error) { + return s.accept(ctx) +} + +func (s *baseServer) accept(ctx context.Context) (quicConn, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case conn := <-s.connQueue: + atomic.AddInt32(&s.connQueueLen, -1) + return conn, nil + case <-s.errorChan: + return nil, s.serverError + } +} + +// Close the server +func (s *baseServer) Close() error { + s.mutex.Lock() + if s.closed { + s.mutex.Unlock() + return nil + } + if s.serverError == nil { + s.serverError = ErrServerClosed + } + // If the server was started with ListenAddr, we created the packet conn. + // We need to close it in order to make the go routine reading from that conn return. + createdPacketConn := s.createdPacketConn + s.closed = true + close(s.errorChan) + s.mutex.Unlock() + + <-s.running + s.connHandler.CloseServer() + if createdPacketConn { + return s.connHandler.Destroy() + } + return nil +} + +func (s *baseServer) setCloseError(e error) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.closed { + return + } + s.closed = true + s.serverError = e + close(s.errorChan) +} + +// Addr returns the server's network address +func (s *baseServer) Addr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *baseServer) handlePacket(p *receivedPacket) { + select { + case s.receivedPackets <- p: + default: + s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { + if wire.IsVersionNegotiationPacket(p.data) { + s.logger.Debugf("Dropping Version Negotiation packet.") + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + // If we're creating a new connection, the packet will be passed to the connection. + // The header will then be parsed again. + hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) + if err != nil && err != wire.ErrUnsupportedVersion { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing packet: %s", err) + return false + } + // Short header packets should never end up here in the first place + if !hdr.IsLongHeader { + panic(fmt.Sprintf("misrouted packet: %#v", hdr)) + } + if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { + s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + // send a Version Negotiation Packet if the client is speaking a different protocol version + if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + if p.Size() < protocol.MinUnknownVersionPacketSize { + s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + if !s.config.DisableVersionNegotiationPackets { + go s.sendVersionNegotiationPacket(p, hdr) + } + return false + } + if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { + // Drop long header packets. + // There's little point in sending a Stateless Reset, since the client + // might not have received the token yet. + s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + + s.logger.Debugf("<- Received Initial packet.") + + if err := s.handleInitialImpl(p, hdr); err != nil { + s.logger.Errorf("Error occurred handling initial packet: %s", err) + } + // Don't put the packet buffer back. + // handleInitialImpl deals with the buffer. + return true +} + +func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { + if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { + p.buffer.Release() + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + } + return errors.New("too short connection ID") + } + + var ( + token *Token + retrySrcConnID *protocol.ConnectionID + ) + origDestConnID := hdr.DestConnectionID + if len(hdr.Token) > 0 { + c, err := s.tokenGenerator.DecodeToken(hdr.Token) + if err == nil { + token = &Token{ + IsRetryToken: c.IsRetryToken, + RemoteAddr: c.RemoteAddr, + SentTime: c.SentTime, + } + if token.IsRetryToken { + origDestConnID = c.OriginalDestConnectionID + retrySrcConnID = &c.RetrySrcConnectionID + } + } + } + if !s.config.AcceptToken(p.remoteAddr, token) { + go func() { + defer p.buffer.Release() + if token != nil && token.IsRetryToken { + if err := s.maybeSendInvalidToken(p, hdr); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } + return + } + if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error sending Retry: %s", err) + } + }() + return nil + } + + if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { + s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) + go func() { + defer p.buffer.Release() + if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + }() + return nil + } + + connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + if err != nil { + return err + } + s.logger.Debugf("Changing connection ID to %s.", connID) + var conn quicConn + tracingID := nextConnTracingID() + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { + var tracer logging.ConnectionTracer + if s.config.Tracer != nil { + // Use the same connection ID that is passed to the client's GetLogWriter callback. + connID := hdr.DestConnectionID + if origDestConnID.Len() > 0 { + connID = origDestConnID + } + tracer = s.config.Tracer.TracerForConnection( + context.WithValue(context.Background(), ConnectionTracingKey, tracingID), + protocol.PerspectiveServer, + connID, + ) + } + conn = s.newConn( + newSendConn(s.conn, p.remoteAddr, p.info), + s.connHandler, + origDestConnID, + retrySrcConnID, + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + s.connHandler.GetStatelessResetToken(connID), + s.config, + s.tlsConf, + s.tokenGenerator, + s.acceptEarlyConns, + tracer, + tracingID, + s.logger, + hdr.Version, + ) + conn.handlePacket(p) + return conn + }); !added { + return nil + } + go conn.run() + go s.handleNewConn(conn) + if conn == nil { + p.buffer.Release() + return nil + } + return nil +} + +func (s *baseServer) handleNewConn(conn quicConn) { + connCtx := conn.Context() + if s.acceptEarlyConns { + // wait until the early connection is ready (or the handshake fails) + select { + case <-conn.earlyConnReady(): + case <-connCtx.Done(): + return + } + } else { + // wait until the handshake is complete (or fails) + select { + case <-conn.HandshakeComplete().Done(): + case <-connCtx.Done(): + return + } + } + + atomic.AddInt32(&s.connQueueLen, 1) + select { + case s.connQueue <- conn: + // blocks until the connection is accepted + case <-connCtx.Done(): + atomic.AddInt32(&s.connQueueLen, -1) + // don't pass connections that were already closed to Accept() + } +} + +func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { + // Log the Initial packet now. + // If no Retry is sent, the packet will be logged by the connection. + (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) + srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) + if err != nil { + return err + } + token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID, srcConnID) + if err != nil { + return err + } + replyHdr := &wire.ExtendedHeader{} + replyHdr.IsLongHeader = true + replyHdr.Type = protocol.PacketTypeRetry + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = srcConnID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.Token = token + if s.logger.Debug() { + s.logger.Debugf("Changing connection ID to %s.", srcConnID) + s.logger.Debugf("-> Sending Retry") + replyHdr.Log(s.logger) + } + + packetBuffer := getPacketBuffer() + defer packetBuffer.Release() + buf := bytes.NewBuffer(packetBuffer.Data) + if err := replyHdr.Write(buf, hdr.Version); err != nil { + return err + } + // append the Retry integrity tag + tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version) + buf.Write(tag[:]) + if s.config.Tracer != nil { + s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) + } + _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB()) + return err +} + +func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { + // Only send INVALID_TOKEN if we can unprotect the packet. + // This makes sure that we won't send it for packets that were corrupted. + sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) + data := p.data[:hdr.ParsedLen()+hdr.Length] + extHdr, err := unpackHeader(opener, hdr, data, hdr.Version) + if err != nil { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) + } + // don't return the error here. Just drop the packet. + return nil + } + hdrLen := extHdr.ParsedLen() + if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { + // don't return the error here. Just drop the packet. + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) + } + return nil + } + if s.logger.Debug() { + s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) + } + return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) +} + +func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { + sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) + return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) +} + +// sendError sends the error as a response to the packet received with header hdr +func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { + packetBuffer := getPacketBuffer() + defer packetBuffer.Release() + buf := bytes.NewBuffer(packetBuffer.Data) + + ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} + + replyHdr := &wire.ExtendedHeader{} + replyHdr.IsLongHeader = true + replyHdr.Type = protocol.PacketTypeInitial + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = hdr.DestConnectionID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.PacketNumberLen = protocol.PacketNumberLen4 + replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) + if err := replyHdr.Write(buf, hdr.Version); err != nil { + return err + } + payloadOffset := buf.Len() + + if err := ccf.Write(buf, hdr.Version); err != nil { + return err + } + + raw := buf.Bytes() + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) + raw = raw[0 : buf.Len()+sealer.Overhead()] + + pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) + sealer.EncryptHeader( + raw[pnOffset+4:pnOffset+4+16], + &raw[0], + raw[pnOffset:payloadOffset], + ) + + replyHdr.Log(s.logger) + wire.LogFrame(s.logger, ccf, true) + if s.config.Tracer != nil { + s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) + } + _, err := s.conn.WritePacket(raw, remoteAddr, info.OOB()) + return err +} + +func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { + s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) + data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) + if s.config.Tracer != nil { + s.config.Tracer.SentPacket( + p.remoteAddr, + &wire.Header{ + IsLongHeader: true, + DestConnectionID: hdr.SrcConnectionID, + SrcConnectionID: hdr.DestConnectionID, + }, + protocol.ByteCount(len(data)), + nil, + ) + } + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + s.logger.Debugf("Error sending Version Negotiation: %s", err) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/stream.go b/vendor/github.com/lucas-clemente/quic-go/stream.go new file mode 100644 index 00000000..95bbcb35 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/stream.go @@ -0,0 +1,149 @@ +package quic + +import ( + "net" + "os" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/ackhandler" + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type deadlineError struct{} + +func (deadlineError) Error() string { return "deadline exceeded" } +func (deadlineError) Temporary() bool { return true } +func (deadlineError) Timeout() bool { return true } +func (deadlineError) Unwrap() error { return os.ErrDeadlineExceeded } + +var errDeadline net.Error = &deadlineError{} + +// The streamSender is notified by the stream about various events. +type streamSender interface { + queueControlFrame(wire.Frame) + onHasStreamData(protocol.StreamID) + // must be called without holding the mutex that is acquired by closeForShutdown + onStreamCompleted(protocol.StreamID) +} + +// Each of the both stream halves gets its own uniStreamSender. +// This is necessary in order to keep track when both halves have been completed. +type uniStreamSender struct { + streamSender + onStreamCompletedImpl func() +} + +func (s *uniStreamSender) queueControlFrame(f wire.Frame) { + s.streamSender.queueControlFrame(f) +} + +func (s *uniStreamSender) onHasStreamData(id protocol.StreamID) { + s.streamSender.onHasStreamData(id) +} + +func (s *uniStreamSender) onStreamCompleted(protocol.StreamID) { + s.onStreamCompletedImpl() +} + +var _ streamSender = &uniStreamSender{} + +type streamI interface { + Stream + closeForShutdown(error) + // for receiving + handleStreamFrame(*wire.StreamFrame) error + handleResetStreamFrame(*wire.ResetStreamFrame) error + getWindowUpdate() protocol.ByteCount + // for sending + hasData() bool + handleStopSendingFrame(*wire.StopSendingFrame) + popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) + updateSendWindow(protocol.ByteCount) +} + +var ( + _ receiveStreamI = (streamI)(nil) + _ sendStreamI = (streamI)(nil) +) + +// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface +// +// Read() and Write() may be called concurrently, but multiple calls to Read() or Write() individually must be synchronized manually. +type stream struct { + receiveStream + sendStream + + completedMutex sync.Mutex + sender streamSender + receiveStreamCompleted bool + sendStreamCompleted bool + + version protocol.VersionNumber +} + +var _ Stream = &stream{} + +// newStream creates a new Stream +func newStream(streamID protocol.StreamID, + sender streamSender, + flowController flowcontrol.StreamFlowController, + version protocol.VersionNumber, +) *stream { + s := &stream{sender: sender, version: version} + senderForSendStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.sendStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) + senderForReceiveStream := &uniStreamSender{ + streamSender: sender, + onStreamCompletedImpl: func() { + s.completedMutex.Lock() + s.receiveStreamCompleted = true + s.checkIfCompleted() + s.completedMutex.Unlock() + }, + } + s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController, version) + return s +} + +// need to define StreamID() here, since both receiveStream and readStream have a StreamID() +func (s *stream) StreamID() protocol.StreamID { + // the result is same for receiveStream and sendStream + return s.sendStream.StreamID() +} + +func (s *stream) Close() error { + return s.sendStream.Close() +} + +func (s *stream) SetDeadline(t time.Time) error { + _ = s.SetReadDeadline(t) // SetReadDeadline never errors + _ = s.SetWriteDeadline(t) // SetWriteDeadline never errors + return nil +} + +// CloseForShutdown closes a stream abruptly. +// It makes Read and Write unblock (and return the error) immediately. +// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST. +func (s *stream) closeForShutdown(err error) { + s.sendStream.closeForShutdown(err) + s.receiveStream.closeForShutdown(err) +} + +// checkIfCompleted is called from the uniStreamSender, when one of the stream halves is completed. +// It makes sure that the onStreamCompleted callback is only called if both receive and send side have completed. +func (s *stream) checkIfCompleted() { + if s.sendStreamCompleted && s.receiveStreamCompleted { + s.sender.onStreamCompleted(s.StreamID()) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map.go b/vendor/github.com/lucas-clemente/quic-go/streams_map.go new file mode 100644 index 00000000..79c1ee91 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map.go @@ -0,0 +1,317 @@ +package quic + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type streamError struct { + message string + nums []protocol.StreamNum +} + +func (e streamError) Error() string { + return e.message +} + +func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { + strError, ok := err.(streamError) + if !ok { + return err + } + ids := make([]interface{}, len(strError.nums)) + for i, num := range strError.nums { + ids[i] = num.StreamID(stype, pers) + } + return fmt.Errorf(strError.Error(), ids...) +} + +type streamOpenErr struct{ error } + +var _ net.Error = &streamOpenErr{} + +func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams } +func (streamOpenErr) Timeout() bool { return false } + +// errTooManyOpenStreams is used internally by the outgoing streams maps. +var errTooManyOpenStreams = errors.New("too many open streams") + +type streamsMap struct { + perspective protocol.Perspective + version protocol.VersionNumber + + maxIncomingBidiStreams uint64 + maxIncomingUniStreams uint64 + + sender streamSender + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController + + mutex sync.Mutex + outgoingBidiStreams *outgoingBidiStreamsMap + outgoingUniStreams *outgoingUniStreamsMap + incomingBidiStreams *incomingBidiStreamsMap + incomingUniStreams *incomingUniStreamsMap + reset bool +} + +var _ streamManager = &streamsMap{} + +func newStreamsMap( + sender streamSender, + newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, + maxIncomingBidiStreams uint64, + maxIncomingUniStreams uint64, + perspective protocol.Perspective, + version protocol.VersionNumber, +) streamManager { + m := &streamsMap{ + perspective: perspective, + newFlowController: newFlowController, + maxIncomingBidiStreams: maxIncomingBidiStreams, + maxIncomingUniStreams: maxIncomingUniStreams, + sender: sender, + version: version, + } + m.initMaps() + return m +} + +func (m *streamsMap) initMaps() { + m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, m.perspective) + return newStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.sender.queueControlFrame, + ) + m.incomingBidiStreams = newIncomingBidiStreamsMap( + func(num protocol.StreamNum) streamI { + id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) + return newStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.maxIncomingBidiStreams, + m.sender.queueControlFrame, + ) + m.outgoingUniStreams = newOutgoingUniStreamsMap( + func(num protocol.StreamNum) sendStreamI { + id := num.StreamID(protocol.StreamTypeUni, m.perspective) + return newSendStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.sender.queueControlFrame, + ) + m.incomingUniStreams = newIncomingUniStreamsMap( + func(num protocol.StreamNum) receiveStreamI { + id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) + return newReceiveStream(id, m.sender, m.newFlowController(id), m.version) + }, + m.maxIncomingUniStreams, + m.sender.queueControlFrame, + ) +} + +func (m *streamsMap) OpenStream() (Stream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStream() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) +} + +func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStreamSync(ctx) + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) +} + +func (m *streamsMap) OpenUniStream() (SendStream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStream() + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) +} + +func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.outgoingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.OpenStreamSync(ctx) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) +} + +func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.incomingBidiStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.AcceptStream(ctx) + return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) +} + +func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { + m.mutex.Lock() + reset := m.reset + mm := m.incomingUniStreams + m.mutex.Unlock() + if reset { + return nil, Err0RTTRejected + } + str, err := mm.AcceptStream(ctx) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) +} + +func (m *streamsMap) DeleteStream(id protocol.StreamID) error { + num := id.StreamNum() + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) + } + return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite()) + case protocol.StreamTypeBidi: + if id.InitiatedBy() == m.perspective { + return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) + } + return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite()) + } + panic("") +} + +func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + str, err := m.getOrOpenReceiveStream(id) + if err != nil { + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: err.Error(), + } + } + return str, nil +} + +func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { + num := id.StreamNum() + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + // an outgoing unidirectional stream is a send stream, not a receive stream + return nil, fmt.Errorf("peer attempted to open receive stream %d", id) + } + str, err := m.incomingUniStreams.GetOrOpenStream(num) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) + case protocol.StreamTypeBidi: + var str receiveStreamI + var err error + if id.InitiatedBy() == m.perspective { + str, err = m.outgoingBidiStreams.GetStream(num) + } else { + str, err = m.incomingBidiStreams.GetOrOpenStream(num) + } + return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + } + panic("") +} + +func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + str, err := m.getOrOpenSendStream(id) + if err != nil { + return nil, &qerr.TransportError{ + ErrorCode: qerr.StreamStateError, + ErrorMessage: err.Error(), + } + } + return str, nil +} + +func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { + num := id.StreamNum() + switch id.Type() { + case protocol.StreamTypeUni: + if id.InitiatedBy() == m.perspective { + str, err := m.outgoingUniStreams.GetStream(num) + return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) + } + // an incoming unidirectional stream is a receive stream, not a send stream + return nil, fmt.Errorf("peer attempted to open send stream %d", id) + case protocol.StreamTypeBidi: + var str sendStreamI + var err error + if id.InitiatedBy() == m.perspective { + str, err = m.outgoingBidiStreams.GetStream(num) + } else { + str, err = m.incomingBidiStreams.GetOrOpenStream(num) + } + return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) + } + panic("") +} + +func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { + switch f.Type { + case protocol.StreamTypeUni: + m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) + case protocol.StreamTypeBidi: + m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) + } +} + +func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { + m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) + m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) + m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) + m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) +} + +func (m *streamsMap) CloseWithError(err error) { + m.outgoingBidiStreams.CloseWithError(err) + m.outgoingUniStreams.CloseWithError(err) + m.incomingBidiStreams.CloseWithError(err) + m.incomingUniStreams.CloseWithError(err) +} + +// ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are +// 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error. +// 2. reset to their initial state, such that we can immediately process new incoming stream data. +// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error, +// until UseResetMaps() has been called. +func (m *streamsMap) ResetFor0RTT() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.reset = true + m.CloseWithError(Err0RTTRejected) + m.initMaps() +} + +func (m *streamsMap) UseResetMaps() { + m.mutex.Lock() + m.reset = false + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go new file mode 100644 index 00000000..26b56233 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go @@ -0,0 +1,18 @@ +package quic + +import ( + "github.com/cheekybits/genny/generic" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// In the auto-generated streams maps, we need to be able to close the streams. +// Therefore, extend the generic.Type with the stream close method. +// This definition must be in a file that Genny doesn't process. +type item interface { + generic.Type + updateSendWindow(protocol.ByteCount) + closeForShutdown(error) +} + +const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go new file mode 100644 index 00000000..46c8c73a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go @@ -0,0 +1,192 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// When a stream is deleted before it was accepted, we can't delete it from the map immediately. +// We need to wait until the application accepts it, and delete it then. +type streamIEntry struct { + stream streamI + shouldDelete bool +} + +type incomingBidiStreamsMap struct { + mutex sync.RWMutex + newStreamChan chan struct{} + + streams map[protocol.StreamNum]streamIEntry + + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams + + newStream func(protocol.StreamNum) streamI + queueMaxStreamID func(*wire.MaxStreamsFrame) + + closeErr error +} + +func newIncomingBidiStreamsMap( + newStream func(protocol.StreamNum) streamI, + maxStreams uint64, + queueControlFrame func(wire.Frame), +) *incomingBidiStreamsMap { + return &incomingBidiStreamsMap{ + newStreamChan: make(chan struct{}, 1), + streams: make(map[protocol.StreamNum]streamIEntry), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, + newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + } +} + +func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + + m.mutex.Lock() + + var num protocol.StreamNum + var entry streamIEntry + for { + num = m.nextStreamToAccept + if m.closeErr != nil { + m.mutex.Unlock() + return nil, m.closeErr + } + var ok bool + entry, ok = m.streams[num] + if ok { + break + } + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() + } + m.nextStreamToAccept++ + // If this stream was completed before being accepted, we can delete it now. + if entry.shouldDelete { + if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() + return nil, err + } + } + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (streamI, error) { + m.mutex.RLock() + if num > m.maxStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } + } + // if the num is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if num < m.nextStreamToOpen { + var s streamI + // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. + if entry, ok := m.streams[num]; ok && !entry.shouldDelete { + s = entry.stream + } + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = streamIEntry{stream: m.newStream(newNum)} + select { + case m.newStreamChan <- struct{}{}: + default: + } + } + m.nextStreamToOpen = num + 1 + entry := m.streams[num] + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteStream(num) +} + +func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown incoming stream %d", + nums: []protocol.StreamNum{num}, + } + } + + // Don't delete this stream yet, if it was not yet accepted. + // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. + if num >= m.nextStreamToAccept { + entry, ok := m.streams[num] + if ok && entry.shouldDelete { + return streamError{ + message: "tried to delete incoming stream %d multiple times", + nums: []protocol.StreamNum{num}, + } + } + entry.shouldDelete = true + m.streams[num] = entry // can't assign to struct in map, so we need to reassign + return nil + } + + delete(m.streams, num) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if m.maxNumStreams > uint64(len(m.streams)) { + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeBidi, + MaxStreamNum: m.maxStream, + }) + } + } + return nil +} + +func (m *incomingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, entry := range m.streams { + entry.stream.closeForShutdown(err) + } + m.mutex.Unlock() + close(m.newStreamChan) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go new file mode 100644 index 00000000..4c7696a0 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go @@ -0,0 +1,190 @@ +package quic + +import ( + "context" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// When a stream is deleted before it was accepted, we can't delete it from the map immediately. +// We need to wait until the application accepts it, and delete it then. +type itemEntry struct { + stream item + shouldDelete bool +} + +//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" +//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" +type incomingItemsMap struct { + mutex sync.RWMutex + newStreamChan chan struct{} + + streams map[protocol.StreamNum]itemEntry + + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams + + newStream func(protocol.StreamNum) item + queueMaxStreamID func(*wire.MaxStreamsFrame) + + closeErr error +} + +func newIncomingItemsMap( + newStream func(protocol.StreamNum) item, + maxStreams uint64, + queueControlFrame func(wire.Frame), +) *incomingItemsMap { + return &incomingItemsMap{ + newStreamChan: make(chan struct{}, 1), + streams: make(map[protocol.StreamNum]itemEntry), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, + newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + } +} + +func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + + m.mutex.Lock() + + var num protocol.StreamNum + var entry itemEntry + for { + num = m.nextStreamToAccept + if m.closeErr != nil { + m.mutex.Unlock() + return nil, m.closeErr + } + var ok bool + entry, ok = m.streams[num] + if ok { + break + } + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() + } + m.nextStreamToAccept++ + // If this stream was completed before being accepted, we can delete it now. + if entry.shouldDelete { + if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() + return nil, err + } + } + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error) { + m.mutex.RLock() + if num > m.maxStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } + } + // if the num is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if num < m.nextStreamToOpen { + var s item + // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. + if entry, ok := m.streams[num]; ok && !entry.shouldDelete { + s = entry.stream + } + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = itemEntry{stream: m.newStream(newNum)} + select { + case m.newStreamChan <- struct{}{}: + default: + } + } + m.nextStreamToOpen = num + 1 + entry := m.streams[num] + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingItemsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteStream(num) +} + +func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown incoming stream %d", + nums: []protocol.StreamNum{num}, + } + } + + // Don't delete this stream yet, if it was not yet accepted. + // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. + if num >= m.nextStreamToAccept { + entry, ok := m.streams[num] + if ok && entry.shouldDelete { + return streamError{ + message: "tried to delete incoming stream %d multiple times", + nums: []protocol.StreamNum{num}, + } + } + entry.shouldDelete = true + m.streams[num] = entry // can't assign to struct in map, so we need to reassign + return nil + } + + delete(m.streams, num) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if m.maxNumStreams > uint64(len(m.streams)) { + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: streamTypeGeneric, + MaxStreamNum: m.maxStream, + }) + } + } + return nil +} + +func (m *incomingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, entry := range m.streams { + entry.stream.closeForShutdown(err) + } + m.mutex.Unlock() + close(m.newStreamChan) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go new file mode 100644 index 00000000..5bddec00 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go @@ -0,0 +1,192 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +// When a stream is deleted before it was accepted, we can't delete it from the map immediately. +// We need to wait until the application accepts it, and delete it then. +type receiveStreamIEntry struct { + stream receiveStreamI + shouldDelete bool +} + +type incomingUniStreamsMap struct { + mutex sync.RWMutex + newStreamChan chan struct{} + + streams map[protocol.StreamNum]receiveStreamIEntry + + nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() + nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened + maxStream protocol.StreamNum // the highest stream that the peer is allowed to open + maxNumStreams uint64 // maximum number of streams + + newStream func(protocol.StreamNum) receiveStreamI + queueMaxStreamID func(*wire.MaxStreamsFrame) + + closeErr error +} + +func newIncomingUniStreamsMap( + newStream func(protocol.StreamNum) receiveStreamI, + maxStreams uint64, + queueControlFrame func(wire.Frame), +) *incomingUniStreamsMap { + return &incomingUniStreamsMap{ + newStreamChan: make(chan struct{}, 1), + streams: make(map[protocol.StreamNum]receiveStreamIEntry), + maxStream: protocol.StreamNum(maxStreams), + maxNumStreams: maxStreams, + newStream: newStream, + nextStreamToOpen: 1, + nextStreamToAccept: 1, + queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, + } +} + +func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { + // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist + select { + case <-m.newStreamChan: + default: + } + + m.mutex.Lock() + + var num protocol.StreamNum + var entry receiveStreamIEntry + for { + num = m.nextStreamToAccept + if m.closeErr != nil { + m.mutex.Unlock() + return nil, m.closeErr + } + var ok bool + entry, ok = m.streams[num] + if ok { + break + } + m.mutex.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.newStreamChan: + } + m.mutex.Lock() + } + m.nextStreamToAccept++ + // If this stream was completed before being accepted, we can delete it now. + if entry.shouldDelete { + if err := m.deleteStream(num); err != nil { + m.mutex.Unlock() + return nil, err + } + } + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receiveStreamI, error) { + m.mutex.RLock() + if num > m.maxStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer tried to open stream %d (current limit: %d)", + nums: []protocol.StreamNum{num, m.maxStream}, + } + } + // if the num is smaller than the highest we accepted + // * this stream exists in the map, and we can return it, or + // * this stream was already closed, then we can return the nil + if num < m.nextStreamToOpen { + var s receiveStreamI + // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. + if entry, ok := m.streams[num]; ok && !entry.shouldDelete { + s = entry.stream + } + m.mutex.RUnlock() + return s, nil + } + m.mutex.RUnlock() + + m.mutex.Lock() + // no need to check the two error conditions from above again + // * maxStream can only increase, so if the id was valid before, it definitely is valid now + // * highestStream is only modified by this function + for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { + m.streams[newNum] = receiveStreamIEntry{stream: m.newStream(newNum)} + select { + case m.newStreamChan <- struct{}{}: + default: + } + } + m.nextStreamToOpen = num + 1 + entry := m.streams[num] + m.mutex.Unlock() + return entry.stream, nil +} + +func (m *incomingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.deleteStream(num) +} + +func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown incoming stream %d", + nums: []protocol.StreamNum{num}, + } + } + + // Don't delete this stream yet, if it was not yet accepted. + // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. + if num >= m.nextStreamToAccept { + entry, ok := m.streams[num] + if ok && entry.shouldDelete { + return streamError{ + message: "tried to delete incoming stream %d multiple times", + nums: []protocol.StreamNum{num}, + } + } + entry.shouldDelete = true + m.streams[num] = entry // can't assign to struct in map, so we need to reassign + return nil + } + + delete(m.streams, num) + // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream + if m.maxNumStreams > uint64(len(m.streams)) { + maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 + // Never send a value larger than protocol.MaxStreamCount. + if maxStream <= protocol.MaxStreamCount { + m.maxStream = maxStream + m.queueMaxStreamID(&wire.MaxStreamsFrame{ + Type: protocol.StreamTypeUni, + MaxStreamNum: m.maxStream, + }) + } + } + return nil +} + +func (m *incomingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, entry := range m.streams { + entry.stream.closeForShutdown(err) + } + m.mutex.Unlock() + close(m.newStreamChan) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go new file mode 100644 index 00000000..3f7ec166 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go @@ -0,0 +1,226 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type outgoingBidiStreamsMap struct { + mutex sync.RWMutex + + streams map[protocol.StreamNum]streamI + + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + + newStream func(protocol.StreamNum) streamI + queueStreamIDBlocked func(*wire.StreamsBlockedFrame) + + closeErr error +} + +func newOutgoingBidiStreamsMap( + newStream func(protocol.StreamNum) streamI, + queueControlFrame func(wire.Frame), +) *outgoingBidiStreamsMap { + return &outgoingBidiStreamsMap{ + streams: make(map[protocol.StreamNum]streamI), + openQueue: make(map[uint64]chan struct{}), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, + } +} + +func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + // if there are OpenStreamSync calls waiting, return an error here + if len(m.openQueue) > 0 || m.nextStream > m.maxStream { + m.maybeSendBlockedFrame() + return nil, streamOpenErr{errTooManyOpenStreams} + } + return m.openStream(), nil +} + +func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + if err := ctx.Err(); err != nil { + return nil, err + } + + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { + return m.openStream(), nil + } + + waitChan := make(chan struct{}, 1) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan + m.maybeSendBlockedFrame() + + for { + m.mutex.Unlock() + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } + m.mutex.Lock() + + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + // no stream available. Continue waiting + continue + } + str := m.openStream() + delete(m.openQueue, queuePos) + m.lowestInQueue = queuePos + 1 + m.unblockOpenSync() + return str, nil + } +} + +func (m *outgoingBidiStreamsMap) openStream() streamI { + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream++ + return s +} + +// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, +// if we haven't sent one for this offset yet +func (m *outgoingBidiStreamsMap) maybeSendBlockedFrame() { + if m.blockedSent { + return + } + + var streamNum protocol.StreamNum + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream + } + m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ + Type: protocol.StreamTypeBidi, + StreamLimit: streamNum, + }) + m.blockedSent = true +} + +func (m *outgoingBidiStreamsMap) GetStream(num protocol.StreamNum) (streamI, error) { + m.mutex.RLock() + if num >= m.nextStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } + } + s := m.streams[num] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown outgoing stream %d", + nums: []protocol.StreamNum{num}, + } + } + delete(m.streams, num) + return nil +} + +func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if num <= m.maxStream { + return + } + m.maxStream = num + m.blockedSent = false + if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { + m.maybeSendBlockedFrame() + } + m.unblockOpenSync() +} + +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingBidiStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + +// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream +func (m *outgoingBidiStreamsMap) unblockOpenSync() { + if len(m.openQueue) == 0 { + return + } + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. + // It's sufficient to only unblock OpenStreamSync once. + select { + case c <- struct{}{}: + default: + } + return + } +} + +func (m *outgoingBidiStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + for _, c := range m.openQueue { + if c != nil { + close(c) + } + } + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go new file mode 100644 index 00000000..dde75043 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go @@ -0,0 +1,224 @@ +package quic + +import ( + "context" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" +//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" +type outgoingItemsMap struct { + mutex sync.RWMutex + + streams map[protocol.StreamNum]item + + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + + newStream func(protocol.StreamNum) item + queueStreamIDBlocked func(*wire.StreamsBlockedFrame) + + closeErr error +} + +func newOutgoingItemsMap( + newStream func(protocol.StreamNum) item, + queueControlFrame func(wire.Frame), +) *outgoingItemsMap { + return &outgoingItemsMap{ + streams: make(map[protocol.StreamNum]item), + openQueue: make(map[uint64]chan struct{}), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, + } +} + +func (m *outgoingItemsMap) OpenStream() (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + // if there are OpenStreamSync calls waiting, return an error here + if len(m.openQueue) > 0 || m.nextStream > m.maxStream { + m.maybeSendBlockedFrame() + return nil, streamOpenErr{errTooManyOpenStreams} + } + return m.openStream(), nil +} + +func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + if err := ctx.Err(); err != nil { + return nil, err + } + + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { + return m.openStream(), nil + } + + waitChan := make(chan struct{}, 1) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan + m.maybeSendBlockedFrame() + + for { + m.mutex.Unlock() + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } + m.mutex.Lock() + + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + // no stream available. Continue waiting + continue + } + str := m.openStream() + delete(m.openQueue, queuePos) + m.lowestInQueue = queuePos + 1 + m.unblockOpenSync() + return str, nil + } +} + +func (m *outgoingItemsMap) openStream() item { + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream++ + return s +} + +// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, +// if we haven't sent one for this offset yet +func (m *outgoingItemsMap) maybeSendBlockedFrame() { + if m.blockedSent { + return + } + + var streamNum protocol.StreamNum + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream + } + m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ + Type: streamTypeGeneric, + StreamLimit: streamNum, + }) + m.blockedSent = true +} + +func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) { + m.mutex.RLock() + if num >= m.nextStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } + } + s := m.streams[num] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown outgoing stream %d", + nums: []protocol.StreamNum{num}, + } + } + delete(m.streams, num) + return nil +} + +func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if num <= m.maxStream { + return + } + m.maxStream = num + m.blockedSent = false + if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { + m.maybeSendBlockedFrame() + } + m.unblockOpenSync() +} + +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingItemsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + +// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream +func (m *outgoingItemsMap) unblockOpenSync() { + if len(m.openQueue) == 0 { + return + } + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. + // It's sufficient to only unblock OpenStreamSync once. + select { + case c <- struct{}{}: + default: + } + return + } +} + +func (m *outgoingItemsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + for _, c := range m.openQueue { + if c != nil { + close(c) + } + } + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go new file mode 100644 index 00000000..8782364a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go @@ -0,0 +1,226 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package quic + +import ( + "context" + "sync" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type outgoingUniStreamsMap struct { + mutex sync.RWMutex + + streams map[protocol.StreamNum]sendStreamI + + openQueue map[uint64]chan struct{} + lowestInQueue uint64 + highestInQueue uint64 + + nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) + maxStream protocol.StreamNum // the maximum stream ID we're allowed to open + blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream + + newStream func(protocol.StreamNum) sendStreamI + queueStreamIDBlocked func(*wire.StreamsBlockedFrame) + + closeErr error +} + +func newOutgoingUniStreamsMap( + newStream func(protocol.StreamNum) sendStreamI, + queueControlFrame func(wire.Frame), +) *outgoingUniStreamsMap { + return &outgoingUniStreamsMap{ + streams: make(map[protocol.StreamNum]sendStreamI), + openQueue: make(map[uint64]chan struct{}), + maxStream: protocol.InvalidStreamNum, + nextStream: 1, + newStream: newStream, + queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, + } +} + +func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + // if there are OpenStreamSync calls waiting, return an error here + if len(m.openQueue) > 0 || m.nextStream > m.maxStream { + m.maybeSendBlockedFrame() + return nil, streamOpenErr{errTooManyOpenStreams} + } + return m.openStream(), nil +} + +func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closeErr != nil { + return nil, m.closeErr + } + + if err := ctx.Err(); err != nil { + return nil, err + } + + if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { + return m.openStream(), nil + } + + waitChan := make(chan struct{}, 1) + queuePos := m.highestInQueue + m.highestInQueue++ + if len(m.openQueue) == 0 { + m.lowestInQueue = queuePos + } + m.openQueue[queuePos] = waitChan + m.maybeSendBlockedFrame() + + for { + m.mutex.Unlock() + select { + case <-ctx.Done(): + m.mutex.Lock() + delete(m.openQueue, queuePos) + return nil, ctx.Err() + case <-waitChan: + } + m.mutex.Lock() + + if m.closeErr != nil { + return nil, m.closeErr + } + if m.nextStream > m.maxStream { + // no stream available. Continue waiting + continue + } + str := m.openStream() + delete(m.openQueue, queuePos) + m.lowestInQueue = queuePos + 1 + m.unblockOpenSync() + return str, nil + } +} + +func (m *outgoingUniStreamsMap) openStream() sendStreamI { + s := m.newStream(m.nextStream) + m.streams[m.nextStream] = s + m.nextStream++ + return s +} + +// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, +// if we haven't sent one for this offset yet +func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() { + if m.blockedSent { + return + } + + var streamNum protocol.StreamNum + if m.maxStream != protocol.InvalidStreamNum { + streamNum = m.maxStream + } + m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ + Type: protocol.StreamTypeUni, + StreamLimit: streamNum, + }) + m.blockedSent = true +} + +func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) { + m.mutex.RLock() + if num >= m.nextStream { + m.mutex.RUnlock() + return nil, streamError{ + message: "peer attempted to open stream %d", + nums: []protocol.StreamNum{num}, + } + } + s := m.streams[num] + m.mutex.RUnlock() + return s, nil +} + +func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.streams[num]; !ok { + return streamError{ + message: "tried to delete unknown outgoing stream %d", + nums: []protocol.StreamNum{num}, + } + } + delete(m.streams, num) + return nil +} + +func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if num <= m.maxStream { + return + } + m.maxStream = num + m.blockedSent = false + if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { + m.maybeSendBlockedFrame() + } + m.unblockOpenSync() +} + +// UpdateSendWindow is called when the peer's transport parameters are received. +// Only in the case of a 0-RTT handshake will we have open streams at this point. +// We might need to update the send window, in case the server increased it. +func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { + m.mutex.Lock() + for _, str := range m.streams { + str.updateSendWindow(limit) + } + m.mutex.Unlock() +} + +// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream +func (m *outgoingUniStreamsMap) unblockOpenSync() { + if len(m.openQueue) == 0 { + return + } + for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { + c, ok := m.openQueue[qp] + if !ok { // entry was deleted because the context was canceled + continue + } + // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. + // It's sufficient to only unblock OpenStreamSync once. + select { + case c <- struct{}{}: + default: + } + return + } +} + +func (m *outgoingUniStreamsMap) CloseWithError(err error) { + m.mutex.Lock() + m.closeErr = err + for _, str := range m.streams { + str.closeForShutdown(err) + } + for _, c := range m.openQueue { + if c != nil { + close(c) + } + } + m.mutex.Unlock() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn.go new file mode 100644 index 00000000..7cc05465 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn.go @@ -0,0 +1,80 @@ +package quic + +import ( + "net" + "syscall" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header. +// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will use it. +// In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets. +type OOBCapablePacketConn interface { + net.PacketConn + SyscallConn() (syscall.RawConn, error) + ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) + WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) +} + +var _ OOBCapablePacketConn = &net.UDPConn{} + +func wrapConn(pc net.PacketConn) (rawConn, error) { + conn, ok := pc.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if ok { + rawConn, err := conn.SyscallConn() + if err != nil { + return nil, err + } + + if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { + // Only set DF on sockets that we expect to be able to handle that configuration. + err = setDF(rawConn) + if err != nil { + return nil, err + } + } + } + c, ok := pc.(OOBCapablePacketConn) + if !ok { + utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") + return &basicConn{PacketConn: pc}, nil + } + return newConn(c) +} + +// The basicConn is the most trivial implementation of a connection. +// It reads a single packet from the underlying net.PacketConn. +// It is used when +// * the net.PacketConn is not a OOBCapablePacketConn, and +// * when the OS doesn't support OOB. +type basicConn struct { + net.PacketConn +} + +var _ rawConn = &basicConn{} + +func (c *basicConn) ReadPacket() (*receivedPacket, error) { + buffer := getPacketBuffer() + // The packet size should not exceed protocol.MaxPacketBufferSize bytes + // If it does, we only read a truncated packet, which will then end up undecryptable + buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] + n, addr, err := c.PacketConn.ReadFrom(buffer.Data) + if err != nil { + return nil, err + } + return &receivedPacket{ + remoteAddr: addr, + rcvTime: time.Now(), + data: buffer.Data[:n], + buffer: buffer, + }, nil +} + +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { + return c.PacketConn.WriteTo(b, addr) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_df.go new file mode 100644 index 00000000..ae9274d9 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_df.go @@ -0,0 +1,16 @@ +//go:build !linux && !windows +// +build !linux,!windows + +package quic + +import "syscall" + +func setDF(rawConn syscall.RawConn) error { + // no-op on unsupported platforms + return nil +} + +func isMsgSizeErr(err error) bool { + // to be implemented for more specific platforms + return false +} diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_linux.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_linux.go new file mode 100644 index 00000000..17ac67f1 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_linux.go @@ -0,0 +1,40 @@ +//go:build linux +// +build linux + +package quic + +import ( + "errors" + "syscall" + + "github.com/lucas-clemente/quic-go/internal/utils" + "golang.org/x/sys/unix" +) + +func setDF(rawConn syscall.RawConn) error { + // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" + // and the datagram will not be fragmented + var errDFIPv4, errDFIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) + errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO) + }); err != nil { + return err + } + switch { + case errDFIPv4 == nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") + case errDFIPv4 == nil && errDFIPv6 != nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4.") + case errDFIPv4 != nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv6.") + case errDFIPv4 != nil && errDFIPv6 != nil: + return errors.New("setting DF failed for both IPv4 and IPv6") + } + return nil +} + +func isMsgSizeErr(err error) bool { + // https://man7.org/linux/man-pages/man7/udp.7.html + return errors.Is(err, unix.EMSGSIZE) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_windows.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_windows.go new file mode 100644 index 00000000..4649f646 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_windows.go @@ -0,0 +1,46 @@ +//go:build windows +// +build windows + +package quic + +import ( + "errors" + "syscall" + + "github.com/lucas-clemente/quic-go/internal/utils" + "golang.org/x/sys/windows" +) + +const ( + // same for both IPv4 and IPv6 on Windows + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html + IP_DONTFRAGMENT = 14 + IPV6_DONTFRAG = 14 +) + +func setDF(rawConn syscall.RawConn) error { + var errDFIPv4, errDFIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) + errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) + }); err != nil { + return err + } + switch { + case errDFIPv4 == nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") + case errDFIPv4 == nil && errDFIPv6 != nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4.") + case errDFIPv4 != nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv6.") + case errDFIPv4 != nil && errDFIPv6 != nil: + return errors.New("setting DF failed for both IPv4 and IPv6") + } + return nil +} + +func isMsgSizeErr(err error) bool { + // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 + return errors.Is(err, windows.WSAEMSGSIZE) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_darwin.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_darwin.go new file mode 100644 index 00000000..eabf489f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_darwin.go @@ -0,0 +1,22 @@ +//go:build darwin +// +build darwin + +package quic + +import "golang.org/x/sys/unix" + +const msgTypeIPTOS = unix.IP_RECVTOS + +const ( + ipv4RECVPKTINFO = unix.IP_RECVPKTINFO + ipv6RECVPKTINFO = 0x3d +) + +const ( + msgTypeIPv4PKTINFO = unix.IP_PKTINFO + msgTypeIPv6PKTINFO = 0x2e +) + +// ReadBatch only returns a single packet on OSX, +// see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. +const batchSize = 1 diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_freebsd.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_freebsd.go new file mode 100644 index 00000000..0b3e8434 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_freebsd.go @@ -0,0 +1,22 @@ +//go:build freebsd +// +build freebsd + +package quic + +import "golang.org/x/sys/unix" + +const ( + msgTypeIPTOS = unix.IP_RECVTOS +) + +const ( + ipv4RECVPKTINFO = 0x7 + ipv6RECVPKTINFO = 0x24 +) + +const ( + msgTypeIPv4PKTINFO = 0x7 + msgTypeIPv6PKTINFO = 0x2e +) + +const batchSize = 8 diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_linux.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_linux.go new file mode 100644 index 00000000..51bec900 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_linux.go @@ -0,0 +1,20 @@ +//go:build linux +// +build linux + +package quic + +import "golang.org/x/sys/unix" + +const msgTypeIPTOS = unix.IP_TOS + +const ( + ipv4RECVPKTINFO = unix.IP_PKTINFO + ipv6RECVPKTINFO = unix.IPV6_RECVPKTINFO +) + +const ( + msgTypeIPv4PKTINFO = unix.IP_PKTINFO + msgTypeIPv6PKTINFO = unix.IPV6_PKTINFO +) + +const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_no_oob.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_no_oob.go new file mode 100644 index 00000000..e3b0d11f --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_no_oob.go @@ -0,0 +1,16 @@ +//go:build !darwin && !linux && !freebsd && !windows +// +build !darwin,!linux,!freebsd,!windows + +package quic + +import "net" + +func newConn(c net.PacketConn) (rawConn, error) { + return &basicConn{PacketConn: c}, nil +} + +func inspectReadBuffer(interface{}) (int, error) { + return 0, nil +} + +func (i *packetInfo) OOB() []byte { return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_oob.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_oob.go new file mode 100644 index 00000000..acd74d02 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_oob.go @@ -0,0 +1,257 @@ +//go:build darwin || linux || freebsd +// +build darwin linux freebsd + +package quic + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "syscall" + "time" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" +) + +const ( + ecnMask = 0x3 + oobBufferSize = 128 +) + +// Contrary to what the naming suggests, the ipv{4,6}.Message is not dependent on the IP version. +// They're both just aliases for x/net/internal/socket.Message. +// This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages. +var _ ipv4.Message = ipv6.Message{} + +type batchConn interface { + ReadBatch(ms []ipv4.Message, flags int) (int, error) +} + +func inspectReadBuffer(c interface{}) (int, error) { + conn, ok := c.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return 0, errors.New("doesn't have a SyscallConn") + } + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) + } + var size int + var serr error + if err := rawConn.Control(func(fd uintptr) { + size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) + }); err != nil { + return 0, err + } + return size, serr +} + +type oobConn struct { + OOBCapablePacketConn + batchConn batchConn + + readPos uint8 + // Packets received from the kernel, but not yet returned by ReadPacket(). + messages []ipv4.Message + buffers [batchSize]*packetBuffer +} + +var _ rawConn = &oobConn{} + +func newConn(c OOBCapablePacketConn) (*oobConn, error) { + rawConn, err := c.SyscallConn() + if err != nil { + return nil, err + } + needsPacketInfo := false + if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() { + needsPacketInfo = true + } + // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. + // Try enabling receiving of ECN and packet info for both IP versions. + // We expect at least one of those syscalls to succeed. + var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) + errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + + if needsPacketInfo { + errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1) + errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1) + } + }); err != nil { + return nil, err + } + switch { + case errECNIPv4 == nil && errECNIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.") + case errECNIPv4 == nil && errECNIPv6 != nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.") + case errECNIPv4 != nil && errECNIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.") + case errECNIPv4 != nil && errECNIPv6 != nil: + return nil, errors.New("activating ECN failed for both IPv4 and IPv6") + } + if needsPacketInfo { + switch { + case errPIIPv4 == nil && errPIIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.") + case errPIIPv4 == nil && errPIIPv6 != nil: + utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.") + case errPIIPv4 != nil && errPIIPv6 == nil: + utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.") + case errPIIPv4 != nil && errPIIPv6 != nil: + return nil, errors.New("activating packet info failed for both IPv4 and IPv6") + } + } + + // Allows callers to pass in a connection that already satisfies batchConn interface + // to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor + // via SyscallConn(), and read it that way, which might not be what the caller wants. + var bc batchConn + if ibc, ok := c.(batchConn); ok { + bc = ibc + } else { + bc = ipv4.NewPacketConn(c) + } + + oobConn := &oobConn{ + OOBCapablePacketConn: c, + batchConn: bc, + messages: make([]ipv4.Message, batchSize), + readPos: batchSize, + } + for i := 0; i < batchSize; i++ { + oobConn.messages[i].OOB = make([]byte, oobBufferSize) + } + return oobConn, nil +} + +func (c *oobConn) ReadPacket() (*receivedPacket, error) { + if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. + c.messages = c.messages[:batchSize] + // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call + for i := uint8(0); i < c.readPos; i++ { + buffer := getPacketBuffer() + buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] + c.buffers[i] = buffer + c.messages[i].Buffers = [][]byte{c.buffers[i].Data} + } + c.readPos = 0 + + n, err := c.batchConn.ReadBatch(c.messages, 0) + if n == 0 || err != nil { + return nil, err + } + c.messages = c.messages[:n] + } + + msg := c.messages[c.readPos] + buffer := c.buffers[c.readPos] + c.readPos++ + ctrlMsgs, err := unix.ParseSocketControlMessage(msg.OOB[:msg.NN]) + if err != nil { + return nil, err + } + var ecn protocol.ECN + var destIP net.IP + var ifIndex uint32 + for _, ctrlMsg := range ctrlMsgs { + if ctrlMsg.Header.Level == unix.IPPROTO_IP { + switch ctrlMsg.Header.Type { + case msgTypeIPTOS: + ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + case msgTypeIPv4PKTINFO: + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* Interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* Header Destination + // address */ + // }; + ip := make([]byte, 4) + if len(ctrlMsg.Data) == 12 { + ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data) + copy(ip, ctrlMsg.Data[8:12]) + } else if len(ctrlMsg.Data) == 4 { + // FreeBSD + copy(ip, ctrlMsg.Data) + } + destIP = net.IP(ip) + } + } + if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 { + switch ctrlMsg.Header.Type { + case unix.IPV6_TCLASS: + ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + case msgTypeIPv6PKTINFO: + // struct in6_pktinfo { + // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ + // unsigned int ipi6_ifindex; /* send/recv interface index */ + // }; + if len(ctrlMsg.Data) == 20 { + ip := make([]byte, 16) + copy(ip, ctrlMsg.Data[:16]) + destIP = net.IP(ip) + ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:]) + } + } + } + } + var info *packetInfo + if destIP != nil { + info = &packetInfo{ + addr: destIP, + ifIndex: ifIndex, + } + } + return &receivedPacket{ + remoteAddr: msg.Addr, + rcvTime: time.Now(), + data: msg.Buffers[0][:msg.N], + ecn: ecn, + info: info, + buffer: buffer, + }, nil +} + +func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { + n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) + return n, err +} + +func (info *packetInfo) OOB() []byte { + if info == nil { + return nil + } + if ip4 := info.addr.To4(); ip4 != nil { + // struct in_pktinfo { + // unsigned int ipi_ifindex; /* Interface index */ + // struct in_addr ipi_spec_dst; /* Local address */ + // struct in_addr ipi_addr; /* Header Destination address */ + // }; + cm := ipv4.ControlMessage{ + Src: ip4, + IfIndex: int(info.ifIndex), + } + return cm.Marshal() + } else if len(info.addr) == 16 { + // struct in6_pktinfo { + // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ + // unsigned int ipi6_ifindex; /* send/recv interface index */ + // }; + cm := ipv6.ControlMessage{ + Src: info.addr, + IfIndex: int(info.ifIndex), + } + return cm.Marshal() + } + return nil +} diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_windows.go b/vendor/github.com/lucas-clemente/quic-go/sys_conn_windows.go new file mode 100644 index 00000000..f2cc22ab --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/sys_conn_windows.go @@ -0,0 +1,40 @@ +//go:build windows +// +build windows + +package quic + +import ( + "errors" + "fmt" + "net" + "syscall" + + "golang.org/x/sys/windows" +) + +func newConn(c OOBCapablePacketConn) (rawConn, error) { + return &basicConn{PacketConn: c}, nil +} + +func inspectReadBuffer(c net.PacketConn) (int, error) { + conn, ok := c.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if !ok { + return 0, errors.New("doesn't have a SyscallConn") + } + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) + } + var size int + var serr error + if err := rawConn.Control(func(fd uintptr) { + size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF) + }); err != nil { + return 0, err + } + return size, serr +} + +func (i *packetInfo) OOB() []byte { return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/token_store.go b/vendor/github.com/lucas-clemente/quic-go/token_store.go new file mode 100644 index 00000000..9641dc5a --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/token_store.go @@ -0,0 +1,117 @@ +package quic + +import ( + "container/list" + "sync" + + "github.com/lucas-clemente/quic-go/internal/utils" +) + +type singleOriginTokenStore struct { + tokens []*ClientToken + len int + p int +} + +func newSingleOriginTokenStore(size int) *singleOriginTokenStore { + return &singleOriginTokenStore{tokens: make([]*ClientToken, size)} +} + +func (s *singleOriginTokenStore) Add(token *ClientToken) { + s.tokens[s.p] = token + s.p = s.index(s.p + 1) + s.len = utils.Min(s.len+1, len(s.tokens)) +} + +func (s *singleOriginTokenStore) Pop() *ClientToken { + s.p = s.index(s.p - 1) + token := s.tokens[s.p] + s.tokens[s.p] = nil + s.len = utils.Max(s.len-1, 0) + return token +} + +func (s *singleOriginTokenStore) Len() int { + return s.len +} + +func (s *singleOriginTokenStore) index(i int) int { + mod := len(s.tokens) + return (i + mod) % mod +} + +type lruTokenStoreEntry struct { + key string + cache *singleOriginTokenStore +} + +type lruTokenStore struct { + mutex sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int + singleOriginSize int +} + +var _ TokenStore = &lruTokenStore{} + +// NewLRUTokenStore creates a new LRU cache for tokens received by the client. +// maxOrigins specifies how many origins this cache is saving tokens for. +// tokensPerOrigin specifies the maximum number of tokens per origin. +func NewLRUTokenStore(maxOrigins, tokensPerOrigin int) TokenStore { + return &lruTokenStore{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: maxOrigins, + singleOriginSize: tokensPerOrigin, + } +} + +func (s *lruTokenStore) Put(key string, token *ClientToken) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if el, ok := s.m[key]; ok { + entry := el.Value.(*lruTokenStoreEntry) + entry.cache.Add(token) + s.q.MoveToFront(el) + return + } + + if s.q.Len() < s.capacity { + entry := &lruTokenStoreEntry{ + key: key, + cache: newSingleOriginTokenStore(s.singleOriginSize), + } + entry.cache.Add(token) + s.m[key] = s.q.PushFront(entry) + return + } + + elem := s.q.Back() + entry := elem.Value.(*lruTokenStoreEntry) + delete(s.m, entry.key) + entry.key = key + entry.cache = newSingleOriginTokenStore(s.singleOriginSize) + entry.cache.Add(token) + s.q.MoveToFront(elem) + s.m[key] = elem +} + +func (s *lruTokenStore) Pop(key string) *ClientToken { + s.mutex.Lock() + defer s.mutex.Unlock() + + var token *ClientToken + if el, ok := s.m[key]; ok { + s.q.MoveToFront(el) + cache := el.Value.(*lruTokenStoreEntry).cache + token = cache.Pop() + if cache.Len() == 0 { + s.q.Remove(el) + delete(s.m, key) + } + } + return token +} diff --git a/vendor/github.com/lucas-clemente/quic-go/tools.go b/vendor/github.com/lucas-clemente/quic-go/tools.go new file mode 100644 index 00000000..ee68fafb --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/tools.go @@ -0,0 +1,9 @@ +//go:build tools +// +build tools + +package quic + +import ( + _ "github.com/cheekybits/genny" + _ "github.com/onsi/ginkgo/ginkgo" +) diff --git a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go new file mode 100644 index 00000000..2abcf673 --- /dev/null +++ b/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go @@ -0,0 +1,71 @@ +package quic + +import ( + "sync" + + "github.com/lucas-clemente/quic-go/internal/flowcontrol" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +type windowUpdateQueue struct { + mutex sync.Mutex + + queue map[protocol.StreamID]struct{} // used as a set + queuedConn bool // connection-level window update + + streamGetter streamGetter + connFlowController flowcontrol.ConnectionFlowController + callback func(wire.Frame) +} + +func newWindowUpdateQueue( + streamGetter streamGetter, + connFC flowcontrol.ConnectionFlowController, + cb func(wire.Frame), +) *windowUpdateQueue { + return &windowUpdateQueue{ + queue: make(map[protocol.StreamID]struct{}), + streamGetter: streamGetter, + connFlowController: connFC, + callback: cb, + } +} + +func (q *windowUpdateQueue) AddStream(id protocol.StreamID) { + q.mutex.Lock() + q.queue[id] = struct{}{} + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) AddConnection() { + q.mutex.Lock() + q.queuedConn = true + q.mutex.Unlock() +} + +func (q *windowUpdateQueue) QueueAll() { + q.mutex.Lock() + // queue a connection-level window update + if q.queuedConn { + q.callback(&wire.MaxDataFrame{MaximumData: q.connFlowController.GetWindowUpdate()}) + q.queuedConn = false + } + // queue all stream-level window updates + for id := range q.queue { + delete(q.queue, id) + str, err := q.streamGetter.GetOrOpenReceiveStream(id) + if err != nil || str == nil { // the stream can be nil if it was completed before dequeing the window update + continue + } + offset := str.getWindowUpdate() + if offset == 0 { // can happen if we received a final offset, right after queueing the window update + continue + } + q.callback(&wire.MaxStreamDataFrame{ + StreamID: id, + MaximumStreamData: offset, + }) + } + q.mutex.Unlock() +} diff --git a/vendor/github.com/marten-seemann/qpack/.codecov.yml b/vendor/github.com/marten-seemann/qpack/.codecov.yml new file mode 100644 index 00000000..00064af3 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/.codecov.yml @@ -0,0 +1,7 @@ +coverage: + round: nearest + status: + project: + default: + threshold: 1 + patch: false diff --git a/vendor/github.com/marten-seemann/qpack/.gitignore b/vendor/github.com/marten-seemann/qpack/.gitignore new file mode 100644 index 00000000..66c189a0 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/.gitignore @@ -0,0 +1,6 @@ +fuzzing/*.zip +fuzzing/coverprofile +fuzzing/crashers +fuzzing/sonarprofile +fuzzing/suppressions +fuzzing/corpus/ diff --git a/vendor/github.com/marten-seemann/qpack/.gitmodules b/vendor/github.com/marten-seemann/qpack/.gitmodules new file mode 100644 index 00000000..5ac16f08 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/.gitmodules @@ -0,0 +1,3 @@ +[submodule "integrationtests/interop/qifs"] + path = integrationtests/interop/qifs + url = https://github.com/qpackers/qifs.git diff --git a/vendor/github.com/marten-seemann/qpack/.golangci.yml b/vendor/github.com/marten-seemann/qpack/.golangci.yml new file mode 100644 index 00000000..4a91adc7 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/.golangci.yml @@ -0,0 +1,27 @@ +run: +linters-settings: +linters: + disable-all: true + enable: + - asciicheck + - deadcode + - exhaustive + - exportloopref + - goconst + - gofmt # redundant, since gofmt *should* be a no-op after gofumpt + - gofumpt + - goimports + - gosimple + - ineffassign + - misspell + - prealloc + - scopelint + - staticcheck + - stylecheck + - structcheck + - unconvert + - unparam + - unused + - varcheck + - vet + diff --git a/vendor/github.com/marten-seemann/qpack/LICENSE.md b/vendor/github.com/marten-seemann/qpack/LICENSE.md new file mode 100644 index 00000000..1ac5a2d9 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/LICENSE.md @@ -0,0 +1,7 @@ +Copyright 2019 Marten Seemann + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/marten-seemann/qpack/README.md b/vendor/github.com/marten-seemann/qpack/README.md new file mode 100644 index 00000000..a621825a --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/README.md @@ -0,0 +1,21 @@ +# QPACK + +[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/marten-seemann/qpack) +[![CircleCI Build Status](https://img.shields.io/circleci/project/github/marten-seemann/qpack.svg?style=flat-square&label=CircleCI+build)](https://circleci.com/gh/marten-seemann/qpack) +[![Code Coverage](https://img.shields.io/codecov/c/github/marten-seemann/qpack/master.svg?style=flat-square)](https://codecov.io/gh/marten-seemann/qpack) + +This is a minimal QPACK implementation in Go. It is minimal in the sense that it doesn't use the dynamic table at all, but just the static table and (Huffman encoded) string literals. Wherever possible, it reuses code from the [HPACK implementation in the Go standard library](https://github.com/golang/net/tree/master/http2/hpack). + +It should be able to interoperate with other QPACK implemetations (both encoders and decoders), however it won't achieve a high compression efficiency. + +## Running the interop tests + +Install the [QPACK interop files](https://github.com/qpackers/qifs/) by running +```bash +git submodule update --init --recursive +``` + +Then run the tests: +```bash +ginkgo -r integrationtests +``` diff --git a/vendor/github.com/marten-seemann/qpack/decoder.go b/vendor/github.com/marten-seemann/qpack/decoder.go new file mode 100644 index 00000000..c9001941 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/decoder.go @@ -0,0 +1,271 @@ +package qpack + +import ( + "bytes" + "errors" + "fmt" + "sync" + + "golang.org/x/net/http2/hpack" +) + +// A decodingError is something the spec defines as a decoding error. +type decodingError struct { + err error +} + +func (de decodingError) Error() string { + return fmt.Sprintf("decoding error: %v", de.err) +} + +// An invalidIndexError is returned when an encoder references a table +// entry before the static table or after the end of the dynamic table. +type invalidIndexError int + +func (e invalidIndexError) Error() string { + return fmt.Sprintf("invalid indexed representation index %d", int(e)) +} + +var errNoDynamicTable = decodingError{errors.New("no dynamic table")} + +// errNeedMore is an internal sentinel error value that means the +// buffer is truncated and we need to read more data before we can +// continue parsing. +var errNeedMore = errors.New("need more data") + +// A Decoder is the decoding context for incremental processing of +// header blocks. +type Decoder struct { + mutex sync.Mutex + + emitFunc func(f HeaderField) + + readRequiredInsertCount bool + readDeltaBase bool + + // buf is the unparsed buffer. It's only written to + // saveBuf if it was truncated in the middle of a header + // block. Because it's usually not owned, we can only + // process it under Write. + buf []byte // not owned; only valid during Write + + // saveBuf is previous data passed to Write which we weren't able + // to fully parse before. Unlike buf, we own this data. + saveBuf bytes.Buffer +} + +// NewDecoder returns a new decoder +// The emitFunc will be called for each valid field parsed, +// in the same goroutine as calls to Write, before Write returns. +func NewDecoder(emitFunc func(f HeaderField)) *Decoder { + return &Decoder{emitFunc: emitFunc} +} + +func (d *Decoder) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + d.mutex.Lock() + n, err := d.writeLocked(p) + d.mutex.Unlock() + return n, err +} + +func (d *Decoder) writeLocked(p []byte) (int, error) { + // Only copy the data if we have to. Optimistically assume + // that p will contain a complete header block. + if d.saveBuf.Len() == 0 { + d.buf = p + } else { + d.saveBuf.Write(p) + d.buf = d.saveBuf.Bytes() + d.saveBuf.Reset() + } + + if err := d.decode(); err != nil { + if err != errNeedMore { + return 0, err + } + // TODO: limit the size of the buffer + d.saveBuf.Write(d.buf) + } + return len(p), nil +} + +// DecodeFull decodes an entire block. +func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) { + if len(p) == 0 { + return []HeaderField{}, nil + } + + d.mutex.Lock() + defer d.mutex.Unlock() + + saveFunc := d.emitFunc + defer func() { d.emitFunc = saveFunc }() + + var hf []HeaderField + d.emitFunc = func(f HeaderField) { hf = append(hf, f) } + if _, err := d.writeLocked(p); err != nil { + return nil, err + } + if err := d.Close(); err != nil { + return nil, err + } + return hf, nil +} + +// Close declares that the decoding is complete and resets the Decoder +// to be reused again for a new header block. If there is any remaining +// data in the decoder's buffer, Close returns an error. +func (d *Decoder) Close() error { + if d.saveBuf.Len() > 0 { + d.saveBuf.Reset() + return decodingError{errors.New("truncated headers")} + } + d.readRequiredInsertCount = false + d.readDeltaBase = false + return nil +} + +func (d *Decoder) decode() error { + if !d.readRequiredInsertCount { + requiredInsertCount, rest, err := readVarInt(8, d.buf) + if err != nil { + return err + } + d.readRequiredInsertCount = true + if requiredInsertCount != 0 { + return decodingError{errors.New("expected Required Insert Count to be zero")} + } + d.buf = rest + } + if !d.readDeltaBase { + base, rest, err := readVarInt(7, d.buf) + if err != nil { + return err + } + d.readDeltaBase = true + if base != 0 { + return decodingError{errors.New("expected Base to be zero")} + } + d.buf = rest + } + if len(d.buf) == 0 { + return errNeedMore + } + + for len(d.buf) > 0 { + b := d.buf[0] + var err error + switch { + case b&0x80 > 0: // 1xxxxxxx + err = d.parseIndexedHeaderField() + case b&0xc0 == 0x40: // 01xxxxxx + err = d.parseLiteralHeaderField() + case b&0xe0 == 0x20: // 001xxxxx + err = d.parseLiteralHeaderFieldWithoutNameReference() + default: + err = fmt.Errorf("unexpected type byte: %#x", b) + } + if err != nil { + return err + } + } + return nil +} + +func (d *Decoder) parseIndexedHeaderField() error { + buf := d.buf + if buf[0]&0x40 == 0 { + return errNoDynamicTable + } + index, buf, err := readVarInt(6, buf) + if err != nil { + return err + } + hf, ok := d.at(index) + if !ok { + return decodingError{invalidIndexError(index)} + } + d.emitFunc(hf) + d.buf = buf + return nil +} + +func (d *Decoder) parseLiteralHeaderField() error { + buf := d.buf + if buf[0]&0x20 > 0 || buf[0]&0x10 == 0 { + return errNoDynamicTable + } + index, buf, err := readVarInt(4, buf) + if err != nil { + return err + } + hf, ok := d.at(index) + if !ok { + return decodingError{invalidIndexError(index)} + } + if len(buf) == 0 { + return errNeedMore + } + usesHuffman := buf[0]&0x80 > 0 + val, buf, err := d.readString(buf, 7, usesHuffman) + if err != nil { + return err + } + hf.Value = val + d.emitFunc(hf) + d.buf = buf + return nil +} + +func (d *Decoder) parseLiteralHeaderFieldWithoutNameReference() error { + buf := d.buf + usesHuffmanForName := buf[0]&0x8 > 0 + name, buf, err := d.readString(buf, 3, usesHuffmanForName) + if err != nil { + return err + } + if len(buf) == 0 { + return errNeedMore + } + usesHuffmanForVal := buf[0]&0x80 > 0 + val, buf, err := d.readString(buf, 7, usesHuffmanForVal) + if err != nil { + return err + } + d.emitFunc(HeaderField{Name: name, Value: val}) + d.buf = buf + return nil +} + +func (d *Decoder) readString(buf []byte, n uint8, usesHuffman bool) (string, []byte, error) { + l, buf, err := readVarInt(n, buf) + if err != nil { + return "", nil, err + } + if uint64(len(buf)) < l { + return "", nil, errNeedMore + } + var val string + if usesHuffman { + var err error + val, err = hpack.HuffmanDecodeToString(buf[:l]) + if err != nil { + return "", nil, err + } + } else { + val = string(buf[:l]) + } + buf = buf[l:] + return val, buf, nil +} + +func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) { + if i >= uint64(len(staticTableEntries)) { + return + } + return staticTableEntries[i], true +} diff --git a/vendor/github.com/marten-seemann/qpack/encoder.go b/vendor/github.com/marten-seemann/qpack/encoder.go new file mode 100644 index 00000000..13e0ad2c --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/encoder.go @@ -0,0 +1,95 @@ +package qpack + +import ( + "io" + + "golang.org/x/net/http2/hpack" +) + +// An Encoder performs QPACK encoding. +type Encoder struct { + wrotePrefix bool + + w io.Writer + buf []byte +} + +// NewEncoder returns a new Encoder which performs QPACK encoding. An +// encoded data is written to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{w: w} +} + +// WriteField encodes f into a single Write to e's underlying Writer. +// This function may also produce bytes for the Header Block Prefix +// if necessary. If produced, it is done before encoding f. +func (e *Encoder) WriteField(f HeaderField) error { + // write the Header Block Prefix + if !e.wrotePrefix { + e.buf = appendVarInt(e.buf, 8, 0) + e.buf = appendVarInt(e.buf, 7, 0) + e.wrotePrefix = true + } + + idxAndVals, nameFound := encoderMap[f.Name] + if nameFound { + if idxAndVals.values == nil { + if len(f.Value) == 0 { + e.writeIndexedField(idxAndVals.idx) + } else { + e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx) + } + } else { + valIdx, valueFound := idxAndVals.values[f.Value] + if valueFound { + e.writeIndexedField(valIdx) + } else { + e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx) + } + } + } else { + e.writeLiteralFieldWithoutNameReference(f) + } + + e.w.Write(e.buf) + e.buf = e.buf[:0] + return nil +} + +// Close declares that the encoding is complete and resets the Encoder +// to be reused again for a new header block. +func (e *Encoder) Close() error { + e.wrotePrefix = false + return nil +} + +func (e *Encoder) writeLiteralFieldWithoutNameReference(f HeaderField) { + offset := len(e.buf) + e.buf = appendVarInt(e.buf, 3, hpack.HuffmanEncodeLength(f.Name)) + e.buf[offset] ^= 0x20 ^ 0x8 + e.buf = hpack.AppendHuffmanString(e.buf, f.Name) + offset = len(e.buf) + e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value)) + e.buf[offset] ^= 0x80 + e.buf = hpack.AppendHuffmanString(e.buf, f.Value) +} + +// Encodes a header field whose name is present in one of the tables. +func (e *Encoder) writeLiteralFieldWithNameReference(f *HeaderField, id uint8) { + offset := len(e.buf) + e.buf = appendVarInt(e.buf, 4, uint64(id)) + // Set the 01NTxxxx pattern, forcing N to 0 and T to 1 + e.buf[offset] ^= 0x50 + offset = len(e.buf) + e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value)) + e.buf[offset] ^= 0x80 + e.buf = hpack.AppendHuffmanString(e.buf, f.Value) +} + +// Encodes an indexed field, meaning it's entirely defined in one of the tables. +func (e *Encoder) writeIndexedField(id uint8) { + offset := len(e.buf) + e.buf = appendVarInt(e.buf, 6, uint64(id)) + // Set the 1Txxxxxx pattern, forcing T to 1 + e.buf[offset] ^= 0xc0 +} diff --git a/vendor/github.com/marten-seemann/qpack/header_field.go b/vendor/github.com/marten-seemann/qpack/header_field.go new file mode 100644 index 00000000..4c043a99 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/header_field.go @@ -0,0 +1,16 @@ +package qpack + +// A HeaderField is a name-value pair. Both the name and value are +// treated as opaque sequences of octets. +type HeaderField struct { + Name string + Value string +} + +// IsPseudo reports whether the header field is an HTTP3 pseudo header. +// That is, it reports whether it starts with a colon. +// It is not otherwise guaranteed to be a valid pseudo header field, +// though. +func (hf HeaderField) IsPseudo() bool { + return len(hf.Name) != 0 && hf.Name[0] == ':' +} diff --git a/vendor/github.com/marten-seemann/qpack/static_table.go b/vendor/github.com/marten-seemann/qpack/static_table.go new file mode 100644 index 00000000..930e83c7 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/static_table.go @@ -0,0 +1,254 @@ +package qpack + +var staticTableEntries = [...]HeaderField{ + {Name: ":authority"}, + {Name: ":path", Value: "/"}, + {Name: "age", Value: "0"}, + {Name: "content-disposition"}, + {Name: "content-length", Value: "0"}, + {Name: "cookie"}, + {Name: "date"}, + {Name: "etag"}, + {Name: "if-modified-since"}, + {Name: "if-none-match"}, + {Name: "last-modified"}, + {Name: "link"}, + {Name: "location"}, + {Name: "referer"}, + {Name: "set-cookie"}, + {Name: ":method", Value: "CONNECT"}, + {Name: ":method", Value: "DELETE"}, + {Name: ":method", Value: "GET"}, + {Name: ":method", Value: "HEAD"}, + {Name: ":method", Value: "OPTIONS"}, + {Name: ":method", Value: "POST"}, + {Name: ":method", Value: "PUT"}, + {Name: ":scheme", Value: "http"}, + {Name: ":scheme", Value: "https"}, + {Name: ":status", Value: "103"}, + {Name: ":status", Value: "200"}, + {Name: ":status", Value: "304"}, + {Name: ":status", Value: "404"}, + {Name: ":status", Value: "503"}, + {Name: "accept", Value: "*/*"}, + {Name: "accept", Value: "application/dns-message"}, + {Name: "accept-encoding", Value: "gzip, deflate, br"}, + {Name: "accept-ranges", Value: "bytes"}, + {Name: "access-control-allow-headers", Value: "cache-control"}, + {Name: "access-control-allow-headers", Value: "content-type"}, + {Name: "access-control-allow-origin", Value: "*"}, + {Name: "cache-control", Value: "max-age=0"}, + {Name: "cache-control", Value: "max-age=2592000"}, + {Name: "cache-control", Value: "max-age=604800"}, + {Name: "cache-control", Value: "no-cache"}, + {Name: "cache-control", Value: "no-store"}, + {Name: "cache-control", Value: "public, max-age=31536000"}, + {Name: "content-encoding", Value: "br"}, + {Name: "content-encoding", Value: "gzip"}, + {Name: "content-type", Value: "application/dns-message"}, + {Name: "content-type", Value: "application/javascript"}, + {Name: "content-type", Value: "application/json"}, + {Name: "content-type", Value: "application/x-www-form-urlencoded"}, + {Name: "content-type", Value: "image/gif"}, + {Name: "content-type", Value: "image/jpeg"}, + {Name: "content-type", Value: "image/png"}, + {Name: "content-type", Value: "text/css"}, + {Name: "content-type", Value: "text/html; charset=utf-8"}, + {Name: "content-type", Value: "text/plain"}, + {Name: "content-type", Value: "text/plain;charset=utf-8"}, + {Name: "range", Value: "bytes=0-"}, + {Name: "strict-transport-security", Value: "max-age=31536000"}, + {Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains"}, + {Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains; preload"}, + {Name: "vary", Value: "accept-encoding"}, + {Name: "vary", Value: "origin"}, + {Name: "x-content-type-options", Value: "nosniff"}, + {Name: "x-xss-protection", Value: "1; mode=block"}, + {Name: ":status", Value: "100"}, + {Name: ":status", Value: "204"}, + {Name: ":status", Value: "206"}, + {Name: ":status", Value: "302"}, + {Name: ":status", Value: "400"}, + {Name: ":status", Value: "403"}, + {Name: ":status", Value: "421"}, + {Name: ":status", Value: "425"}, + {Name: ":status", Value: "500"}, + {Name: "accept-language"}, + {Name: "access-control-allow-credentials", Value: "FALSE"}, + {Name: "access-control-allow-credentials", Value: "TRUE"}, + {Name: "access-control-allow-headers", Value: "*"}, + {Name: "access-control-allow-methods", Value: "get"}, + {Name: "access-control-allow-methods", Value: "get, post, options"}, + {Name: "access-control-allow-methods", Value: "options"}, + {Name: "access-control-expose-headers", Value: "content-length"}, + {Name: "access-control-request-headers", Value: "content-type"}, + {Name: "access-control-request-method", Value: "get"}, + {Name: "access-control-request-method", Value: "post"}, + {Name: "alt-svc", Value: "clear"}, + {Name: "authorization"}, + {Name: "content-security-policy", Value: "script-src 'none'; object-src 'none'; base-uri 'none'"}, + {Name: "early-data", Value: "1"}, + {Name: "expect-ct"}, + {Name: "forwarded"}, + {Name: "if-range"}, + {Name: "origin"}, + {Name: "purpose", Value: "prefetch"}, + {Name: "server"}, + {Name: "timing-allow-origin", Value: "*"}, + {Name: "upgrade-insecure-requests", Value: "1"}, + {Name: "user-agent"}, + {Name: "x-forwarded-for"}, + {Name: "x-frame-options", Value: "deny"}, + {Name: "x-frame-options", Value: "sameorigin"}, +} + +// Only needed for tests. +// use go:linkname to retrieve the static table. +//nolint:deadcode,unused +func getStaticTable() []HeaderField { + return staticTableEntries[:] +} + +type indexAndValues struct { + idx uint8 + values map[string]uint8 +} + +// A map of the header names from the static table to their index in the table. +// This is used by the encoder to quickly find if a header is in the static table +// and what value should be used to encode it. +// There's a second level of mapping for the headers that have some predefined +// values in the static table. +var encoderMap = map[string]indexAndValues{ + ":authority": {0, nil}, + ":path": {1, map[string]uint8{"/": 1}}, + "age": {2, map[string]uint8{"0": 2}}, + "content-disposition": {3, nil}, + "content-length": {4, map[string]uint8{"0": 4}}, + "cookie": {5, nil}, + "date": {6, nil}, + "etag": {7, nil}, + "if-modified-since": {8, nil}, + "if-none-match": {9, nil}, + "last-modified": {10, nil}, + "link": {11, nil}, + "location": {12, nil}, + "referer": {13, nil}, + "set-cookie": {14, nil}, + ":method": {15, map[string]uint8{ + "CONNECT": 15, + "DELETE": 16, + "GET": 17, + "HEAD": 18, + "OPTIONS": 19, + "POST": 20, + "PUT": 21, + }}, + ":scheme": {22, map[string]uint8{ + "http": 22, + "https": 23, + }}, + ":status": {24, map[string]uint8{ + "103": 24, + "200": 25, + "304": 26, + "404": 27, + "503": 28, + "100": 63, + "204": 64, + "206": 65, + "302": 66, + "400": 67, + "403": 68, + "421": 69, + "425": 70, + "500": 71, + }}, + "accept": {29, map[string]uint8{ + "*/*": 29, + "application/dns-message": 30, + }}, + "accept-encoding": {31, map[string]uint8{"gzip, deflate, br": 31}}, + "accept-ranges": {32, map[string]uint8{"bytes": 32}}, + "access-control-allow-headers": {33, map[string]uint8{ + "cache-control": 33, + "content-type": 34, + "*": 75, + }}, + "access-control-allow-origin": {35, map[string]uint8{"*": 35}}, + "cache-control": {36, map[string]uint8{ + "max-age=0": 36, + "max-age=2592000": 37, + "max-age=604800": 38, + "no-cache": 39, + "no-store": 40, + "public, max-age=31536000": 41, + }}, + "content-encoding": {42, map[string]uint8{ + "br": 42, + "gzip": 43, + }}, + "content-type": {44, map[string]uint8{ + "application/dns-message": 44, + "application/javascript": 45, + "application/json": 46, + "application/x-www-form-urlencoded": 47, + "image/gif": 48, + "image/jpeg": 49, + "image/png": 50, + "text/css": 51, + "text/html; charset=utf-8": 52, + "text/plain": 53, + "text/plain;charset=utf-8": 54, + }}, + "range": {55, map[string]uint8{"bytes=0-": 55}}, + "strict-transport-security": {56, map[string]uint8{ + "max-age=31536000": 56, + "max-age=31536000; includesubdomains": 57, + "max-age=31536000; includesubdomains; preload": 58, + }}, + "vary": {59, map[string]uint8{ + "accept-encoding": 59, + "origin": 60, + }}, + "x-content-type-options": {61, map[string]uint8{"nosniff": 61}}, + "x-xss-protection": {62, map[string]uint8{"1; mode=block": 62}}, + // ":status" is duplicated and takes index 63 to 71 + "accept-language": {72, nil}, + "access-control-allow-credentials": {73, map[string]uint8{ + "FALSE": 73, + "TRUE": 74, + }}, + // "access-control-allow-headers" is duplicated and takes index 75 + "access-control-allow-methods": {76, map[string]uint8{ + "get": 76, + "get, post, options": 77, + "options": 78, + }}, + "access-control-expose-headers": {79, map[string]uint8{"content-length": 79}}, + "access-control-request-headers": {80, map[string]uint8{"content-type": 80}}, + "access-control-request-method": {81, map[string]uint8{ + "get": 81, + "post": 82, + }}, + "alt-svc": {83, map[string]uint8{"clear": 83}}, + "authorization": {84, nil}, + "content-security-policy": {85, map[string]uint8{ + "script-src 'none'; object-src 'none'; base-uri 'none'": 85, + }}, + "early-data": {86, map[string]uint8{"1": 86}}, + "expect-ct": {87, nil}, + "forwarded": {88, nil}, + "if-range": {89, nil}, + "origin": {90, nil}, + "purpose": {91, map[string]uint8{"prefetch": 91}}, + "server": {92, nil}, + "timing-allow-origin": {93, map[string]uint8{"*": 93}}, + "upgrade-insecure-requests": {94, map[string]uint8{"1": 94}}, + "user-agent": {95, nil}, + "x-forwarded-for": {96, nil}, + "x-frame-options": {97, map[string]uint8{ + "deny": 97, + "sameorigin": 98, + }}, +} diff --git a/vendor/github.com/marten-seemann/qpack/varint.go b/vendor/github.com/marten-seemann/qpack/varint.go new file mode 100644 index 00000000..28d71122 --- /dev/null +++ b/vendor/github.com/marten-seemann/qpack/varint.go @@ -0,0 +1,66 @@ +package qpack + +// copied from the Go standard library HPACK implementation + +import "errors" + +var errVarintOverflow = errors.New("varint integer overflow") + +// appendVarInt appends i, as encoded in variable integer form using n +// bit prefix, to dst and returns the extended buffer. +// +// See +// http://http2.github.io/http2-spec/compression.html#integer.representation +func appendVarInt(dst []byte, n byte, i uint64) []byte { + k := uint64((1 << n) - 1) + if i < k { + return append(dst, byte(i)) + } + dst = append(dst, byte(k)) + i -= k + for ; i >= 128; i >>= 7 { + dst = append(dst, byte(0x80|(i&0x7f))) + } + return append(dst, byte(i)) +} + +// readVarInt reads an unsigned variable length integer off the +// beginning of p. n is the parameter as described in +// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1. +// +// n must always be between 1 and 8. +// +// The returned remain buffer is either a smaller suffix of p, or err != nil. +// The error is errNeedMore if p doesn't contain a complete integer. +func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) { + if n < 1 || n > 8 { + panic("bad n") + } + if len(p) == 0 { + return 0, p, errNeedMore + } + i = uint64(p[0]) + if n < 8 { + i &= (1 << uint64(n)) - 1 + } + if i < (1< 0 { + b := p[0] + p = p[1:] + i += uint64(b&127) << m + if b&128 == 0 { + return i, p, nil + } + m += 7 + if m >= 63 { // TODO: proper overflow check. making this up. + return 0, origP, errVarintOverflow + } + } + return 0, origP, errNeedMore +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/LICENSE b/vendor/github.com/marten-seemann/qtls-go1-16/LICENSE new file mode 100644 index 00000000..6a66aea5 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/README.md b/vendor/github.com/marten-seemann/qtls-go1-16/README.md new file mode 100644 index 00000000..0d318abe --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/README.md @@ -0,0 +1,6 @@ +# qtls + +[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/marten-seemann/qtls) +[![CircleCI Build Status](https://img.shields.io/circleci/project/github/marten-seemann/qtls.svg?style=flat-square&label=CircleCI+build)](https://circleci.com/gh/marten-seemann/qtls) + +This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/alert.go b/vendor/github.com/marten-seemann/qtls-go1-16/alert.go new file mode 100644 index 00000000..3feac79b --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/alert.go @@ -0,0 +1,102 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import "strconv" + +type alert uint8 + +// Alert is a TLS alert +type Alert = alert + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertExportRestriction alert = 60 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertInappropriateFallback alert = 86 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 + alertMissingExtension alert = 109 + alertUnsupportedExtension alert = 110 + alertCertificateUnobtainable alert = 111 + alertUnrecognizedName alert = 112 + alertBadCertificateStatusResponse alert = 113 + alertBadCertificateHashValue alert = 114 + alertUnknownPSKIdentity alert = 115 + alertCertificateRequired alert = 116 + alertNoApplicationProtocol alert = 120 +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertExportRestriction: "export restriction", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertInappropriateFallback: "inappropriate fallback", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", + alertMissingExtension: "missing extension", + alertUnsupportedExtension: "unsupported extension", + alertCertificateUnobtainable: "certificate unobtainable", + alertUnrecognizedName: "unrecognized name", + alertBadCertificateStatusResponse: "bad certificate status response", + alertBadCertificateHashValue: "bad certificate hash value", + alertUnknownPSKIdentity: "unknown PSK identity", + alertCertificateRequired: "certificate required", + alertNoApplicationProtocol: "no application protocol", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return "tls: " + s + } + return "tls: alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/auth.go b/vendor/github.com/marten-seemann/qtls-go1-16/auth.go new file mode 100644 index 00000000..1ef675fd --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/auth.go @@ -0,0 +1,289 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + "hash" + "io" +) + +// verifyHandshakeSignature verifies a signature against pre-hashed +// (if required) handshake contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) + } + if !ecdsa.VerifyASN1(pubKey, signed, sig) { + return errors.New("ECDSA verification failure") + } + case signatureEd25519: + pubKey, ok := pubkey.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) + } + if !ed25519.Verify(pubKey, signed, sig) { + return errors.New("Ed25519 verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { + return err + } + case signatureRSAPSS: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { + return err + } + default: + return errors.New("internal error: unknown signature type") + } + return nil +} + +const ( + serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" + clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" +) + +var signaturePadding = []byte{ + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, +} + +// signedMessage returns the pre-hashed (if necessary) message to be signed by +// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. +func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { + if sigHash == directSigning { + b := &bytes.Buffer{} + b.Write(signaturePadding) + io.WriteString(b, context) + b.Write(transcript.Sum(nil)) + return b.Bytes() + } + h := sigHash.New() + h.Write(signaturePadding) + io.WriteString(h, context) + h.Write(transcript.Sum(nil)) + return h.Sum(nil) +} + +// typeAndHashFromSignatureScheme returns the corresponding signature type and +// crypto.Hash for a given TLS SignatureScheme. +func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { + switch signatureAlgorithm { + case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: + sigType = signaturePKCS1v15 + case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: + sigType = signatureRSAPSS + case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: + sigType = signatureECDSA + case Ed25519: + sigType = signatureEd25519 + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + switch signatureAlgorithm { + case PKCS1WithSHA1, ECDSAWithSHA1: + hash = crypto.SHA1 + case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: + hash = crypto.SHA256 + case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: + hash = crypto.SHA384 + case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: + hash = crypto.SHA512 + case Ed25519: + hash = directSigning + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + return sigType, hash, nil +} + +// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for +// a given public key used with TLS 1.0 and 1.1, before the introduction of +// signature algorithm negotiation. +func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { + switch pub.(type) { + case *rsa.PublicKey: + return signaturePKCS1v15, crypto.MD5SHA1, nil + case *ecdsa.PublicKey: + return signatureECDSA, crypto.SHA1, nil + case ed25519.PublicKey: + // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, + // but it requires holding on to a handshake transcript to do a + // full signature, and not even OpenSSL bothers with the + // complexity, so we can't even test it properly. + return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") + default: + return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) + } +} + +var rsaSignatureSchemes = []struct { + scheme SignatureScheme + minModulusBytes int + maxVersion uint16 +}{ + // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires + // emLen >= hLen + sLen + 2 + {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, + // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires + // emLen >= len(prefix) + hLen + 11 + // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. + {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, + {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, + {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, + {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, +} + +// signatureSchemesForCertificate returns the list of supported SignatureSchemes +// for a given certificate, based on the public key and the protocol version, +// and optionally filtered by its explicit SupportedSignatureAlgorithms. +// +// This function must be kept in sync with supportedSignatureAlgorithms. +func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil + } + + var sigAlgs []SignatureScheme + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + if version != VersionTLS13 { + // In TLS 1.2 and earlier, ECDSA algorithms are not + // constrained to a single curve. + sigAlgs = []SignatureScheme{ + ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + ECDSAWithSHA1, + } + break + } + switch pub.Curve { + case elliptic.P256(): + sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return nil + } + case *rsa.PublicKey: + size := pub.Size() + sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) + for _, candidate := range rsaSignatureSchemes { + if size >= candidate.minModulusBytes && version <= candidate.maxVersion { + sigAlgs = append(sigAlgs, candidate.scheme) + } + } + case ed25519.PublicKey: + sigAlgs = []SignatureScheme{Ed25519} + default: + return nil + } + + if cert.SupportedSignatureAlgorithms != nil { + var filteredSigAlgs []SignatureScheme + for _, sigAlg := range sigAlgs { + if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { + filteredSigAlgs = append(filteredSigAlgs, sigAlg) + } + } + return filteredSigAlgs + } + return sigAlgs +} + +// selectSignatureScheme picks a SignatureScheme from the peer's preference list +// that works with the selected certificate. It's only called for protocol +// versions that support signature algorithms, so TLS 1.2 and 1.3. +func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { + supportedAlgs := signatureSchemesForCertificate(vers, c) + if len(supportedAlgs) == 0 { + return 0, unsupportedCertificateError(c) + } + if len(peerAlgs) == 0 && vers == VersionTLS12 { + // For TLS 1.2, if the client didn't send signature_algorithms then we + // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. + peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} + } + // Pick signature scheme in the peer's preference order, as our + // preference order is not configurable. + for _, preferredAlg := range peerAlgs { + if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { + return preferredAlg, nil + } + } + return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") +} + +// unsupportedCertificateError returns a helpful error for certificates with +// an unsupported private key. +func unsupportedCertificateError(cert *Certificate) error { + switch cert.PrivateKey.(type) { + case rsa.PrivateKey, ecdsa.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", + cert.PrivateKey, cert.PrivateKey) + case *ed25519.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") + } + + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", + cert.PrivateKey) + } + + switch pub := signer.Public().(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + case elliptic.P384(): + case elliptic.P521(): + default: + return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) + } + case *rsa.PublicKey: + return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") + case ed25519.PublicKey: + default: + return fmt.Errorf("tls: unsupported certificate key (%T)", pub) + } + + if cert.SupportedSignatureAlgorithms != nil { + return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") + } + + return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go b/vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go new file mode 100644 index 00000000..78f80107 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go @@ -0,0 +1,532 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "fmt" + "hash" + + "golang.org/x/crypto/chacha20poly1305" +) + +// CipherSuite is a TLS cipher suite. Note that most functions in this package +// accept and expose cipher suite IDs instead of this type. +type CipherSuite struct { + ID uint16 + Name string + + // Supported versions is the list of TLS protocol versions that can + // negotiate this cipher suite. + SupportedVersions []uint16 + + // Insecure is true if the cipher suite has known security issues + // due to its primitives, design, or implementation. + Insecure bool +} + +var ( + supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} + supportedOnlyTLS12 = []uint16{VersionTLS12} + supportedOnlyTLS13 = []uint16{VersionTLS13} +) + +// CipherSuites returns a list of cipher suites currently implemented by this +// package, excluding those with security issues, which are returned by +// InsecureCipherSuites. +// +// The list is sorted by ID. Note that the default cipher suites selected by +// this package might depend on logic that can't be captured by a static list. +func CipherSuites() []*CipherSuite { + return []*CipherSuite{ + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + + {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, + {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, + {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, + + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + } +} + +// InsecureCipherSuites returns a list of cipher suites currently implemented by +// this package and which have security issues. +// +// Most applications should not use the cipher suites in this list, and should +// only use those returned by CipherSuites. +func InsecureCipherSuites() []*CipherSuite { + // RC4 suites are broken because RC4 is. + // CBC-SHA256 suites have no Lucky13 countermeasures. + return []*CipherSuite{ + {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + } +} + +// CipherSuiteName returns the standard name for the passed cipher suite ID +// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation +// of the ID value if the cipher suite is not implemented by this package. +func CipherSuiteName(id uint16) string { + for _, c := range CipherSuites() { + if c.ID == id { + return c.Name + } + } + for _, c := range InsecureCipherSuites() { + if c.ID == id { + return c.Name + } + } + return fmt.Sprintf("0x%04X", id) +} + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + +const ( + // suiteECDHE indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECSign indicates that the cipher suite involves an ECDSA or + // EdDSA signature and therefore may only be selected when the server's + // certificate is ECDSA or EdDSA. If this is not set then the cipher suite + // is RSA based. + suiteECSign + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 + // suiteDefaultOff indicates that this cipher suite is not included by + // default. + suiteDefaultOff +) + +// A cipherSuite is a specific combination of key agreement, cipher and MAC function. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) interface{} + mac func(key []byte) hash.Hash + aead func(key, fixedNonce []byte) aead +} + +var cipherSuites = []*cipherSuite{ + // Ciphersuite order is chosen so that ECDHE comes before plain RSA and + // AEADs are the top preference. + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + + // RC4-based cipher suites are disabled by default. + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteDefaultOff, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteDefaultOff, cipherRC4, macSHA1, nil}, +} + +// selectCipherSuite returns the first cipher suite from ids which is also in +// supportedIDs and passes the ok filter. +func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { + for _, id := range ids { + candidate := cipherSuiteByID(id) + if candidate == nil || !ok(candidate) { + continue + } + + for _, suppID := range supportedIDs { + if id == suppID { + return candidate + } + } + } + return nil +} + +// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash +// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. +type cipherSuiteTLS13 struct { + id uint16 + keyLen int + aead func(key, fixedNonce []byte) aead + hash crypto.Hash +} + +type CipherSuiteTLS13 struct { + ID uint16 + KeyLen int + Hash crypto.Hash + AEAD func(key, fixedNonce []byte) cipher.AEAD +} + +func (c *CipherSuiteTLS13) IVLen() int { + return aeadNonceLength +} + +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ + {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, + {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, + {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, +} + +func cipherRC4(key, iv []byte, isRead bool) interface{} { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) interface{} { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) interface{} { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a SHA-1 based constant time MAC. +func macSHA1(key []byte) hash.Hash { + return hmac.New(newConstantTimeHash(sha1.New), key) +} + +// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and +// is currently only used in disabled-by-default cipher suites. +func macSHA256(key []byte) hash.Hash { + return hmac.New(sha256.New, key) +} + +type aead interface { + cipher.AEAD + + // explicitNonceLen returns the number of bytes of explicit nonce + // included in each record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +const ( + aeadNonceLength = 12 + noncePrefixLength = 4 +) + +// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type prefixNonceAEAD struct { + // nonce contains the fixed part of the nonce in the first four bytes. + nonce [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } +func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } + +func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.nonce[4:], nonce) + return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) +} + +func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + copy(f.nonce[4:], nonce) + return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) +} + +// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} + +func aeadAESGCM(key, noncePrefix []byte) aead { + if len(noncePrefix) != noncePrefixLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &prefixNonceAEAD{aead: aead} + copy(ret.nonce[:], noncePrefix) + return ret +} + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return aeadAESGCMTLS13(key, fixedNonce) +} + +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. +func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { + h.Reset() + h.Write(seq) + h.Write(header) + h.Write(data) + res := h.Sum(out) + if extra != nil { + h.Write(extra) + } + return res +} + +func rsaKA(version uint16) keyAgreement { + return rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: false, + version: version, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: true, + version: version, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + return cipherSuiteByID(id) + } + } + return nil +} + +func cipherSuiteByID(id uint16) *cipherSuite { + for _, cipherSuite := range cipherSuites { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { + for _, id := range have { + if id == want { + return cipherSuiteTLS13ByID(id) + } + } + return nil +} + +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { + for _, cipherSuite := range cipherSuitesTLS13 { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +// A list of cipher suite IDs that are, or have been, implemented by this +// package. +// +// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 + + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See RFC 7507. + TLS_FALLBACK_SCSV uint16 = 0x5600 + + // Legacy names for the corresponding cipher suites with the correct _SHA256 + // suffix, retained for backward compatibility. + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/common.go b/vendor/github.com/marten-seemann/qtls-go1-16/common.go new file mode 100644 index 00000000..266f93fe --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/common.go @@ -0,0 +1,1576 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "container/list" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "sort" + "strings" + "sync" + "time" +) + +const ( + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 + + // Deprecated: SSLv3 is cryptographically broken, and is no longer + // supported by this package. See golang.org/issue/32716. + VersionSSL30 = 0x0300 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxUselessRecords = 16 // maximum number of consecutive non-advancing records +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeMessageHash uint8 = 254 // synthetic message +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +type Extension struct { + Type uint16 + Data []byte +} + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionRenegotiationInfo uint16 = 0xff01 +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +type EncryptionLevel uint8 + +const ( + EncryptionHandshake EncryptionLevel = iota + Encryption0RTT + EncryptionApplication +) + +// CurveID is a tls.CurveID +type CurveID = tls.CurveID + +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 +) + +// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. +const ( + pskModePlain uint8 = 0 + pskModeDHE uint8 = 1 +) + +// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved +// session. See RFC 8446, Section 4.2.11. +type pskIdentity struct { + label []byte + obfuscatedTicketAge uint32 +} + +// TLS Elliptic Curve Point Formats +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 + certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. +) + +// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with +// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. +const ( + signaturePKCS1v15 uint8 = iota + 225 + signatureRSAPSS + signatureECDSA + signatureEd25519 +) + +// directSigning is a standard Hash value that signals that no pre-hashing +// should be performed, and that the input should be signed directly. It is the +// hash function associated with the Ed25519 signature scheme. +var directSigning crypto.Hash = 0 + +// supportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ +// CertificateRequest. The two fields are merged to match with TLS 1.3. +// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. +var supportedSignatureAlgorithms = []SignatureScheme{ + PSSWithSHA256, + ECDSAWithP256AndSHA256, + Ed25519, + PSSWithSHA384, + PSSWithSHA512, + PKCS1WithSHA256, + PKCS1WithSHA384, + PKCS1WithSHA512, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// helloRetryRequestRandom is set as the Random value of a ServerHello +// to signal that the message is actually a HelloRetryRequest. +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +const ( + // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server + // random as a downgrade protection if the server would be capable of + // negotiating a higher version. See RFC 8446, Section 4.1.3. + downgradeCanaryTLS12 = "DOWNGRD\x01" + downgradeCanaryTLS11 = "DOWNGRD\x00" +) + +// testingOnlyForceDowngradeCanary is set in tests to force the server side to +// include downgrade canaries even if it's using its highers supported version. +var testingOnlyForceDowngradeCanary bool + +type ConnectionState = tls.ConnectionState + +// ConnectionState records basic TLS details about the connection. +type connectionState struct { + // Version is the TLS version used by the connection (e.g. VersionTLS12). + Version uint16 + + // HandshakeComplete is true if the handshake has concluded. + HandshakeComplete bool + + // DidResume is true if this connection was successfully resumed from a + // previous session with a session ticket or similar mechanism. + DidResume bool + + // CipherSuite is the cipher suite negotiated for the connection (e.g. + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). + CipherSuite uint16 + + // NegotiatedProtocol is the application protocol negotiated with ALPN. + NegotiatedProtocol string + + // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. + // + // Deprecated: this value is always true. + NegotiatedProtocolIsMutual bool + + // ServerName is the value of the Server Name Indication extension sent by + // the client. It's available both on the server and on the client side. + ServerName string + + // PeerCertificates are the parsed certificates sent by the peer, in the + // order in which they were sent. The first element is the leaf certificate + // that the connection is verified against. + // + // On the client side, it can't be empty. On the server side, it can be + // empty if Config.ClientAuth is not RequireAnyClientCert or + // RequireAndVerifyClientCert. + PeerCertificates []*x509.Certificate + + // VerifiedChains is a list of one or more chains where the first element is + // PeerCertificates[0] and the last element is from Config.RootCAs (on the + // client side) or Config.ClientCAs (on the server side). + // + // On the client side, it's set if Config.InsecureSkipVerify is false. On + // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven + // (and the peer provided a certificate) or RequireAndVerifyClientCert. + VerifiedChains [][]*x509.Certificate + + // SignedCertificateTimestamps is a list of SCTs provided by the peer + // through the TLS handshake for the leaf certificate, if any. + SignedCertificateTimestamps [][]byte + + // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) + // response provided by the peer for the leaf certificate, if any. + OCSPResponse []byte + + // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, + // Section 3). This value will be nil for TLS 1.3 connections and for all + // resumed connections. + // + // Deprecated: there are conditions in which this value might not be unique + // to a connection. See the Security Considerations sections of RFC 5705 and + // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. + TLSUnique []byte + + // ekm is a closure exposed via ExportKeyingMaterial. + ekm func(label string, context []byte, length int) ([]byte, error) +} + +type ConnectionStateWith0RTT struct { + ConnectionState + + Used0RTT bool // true if 0-RTT was both offered and accepted +} + +// ClientAuthType is tls.ClientAuthType +type ClientAuthType = tls.ClientAuthType + +const ( + NoClientCert = tls.NoClientCert + RequestClientCert = tls.RequestClientCert + RequireAnyClientCert = tls.RequireAnyClientCert + VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven + RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert +) + +// requiresClientCert reports whether the ClientAuthType requires a client +// certificate to be provided. +func requiresClientCert(c ClientAuthType) bool { + switch c { + case RequireAnyClientCert, RequireAndVerifyClientCert: + return true + default: + return false + } +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState = tls.ClientSessionState + +type clientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + receivedAt time.Time // When the session ticket was received from the server + ocspResponse []byte // Stapled OCSP response presented by the server + scts [][]byte // SCTs presented by the server + + // TLS 1.3 fields. + nonce []byte // Ticket nonce sent by the server, to derive PSK + useBy time.Time // Expiration of the ticket lifetime as set by the server + ageAdd uint32 // Random obfuscation factor for sending the ticket age +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not +// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which +// are supported via this interface. +//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-15 ClientSessionCache" +type ClientSessionCache = tls.ClientSessionCache + +// SignatureScheme is a tls.SignatureScheme +type SignatureScheme = tls.SignatureScheme + +const ( + // RSASSA-PKCS1-v1_5 algorithms. + PKCS1WithSHA256 SignatureScheme = 0x0401 + PKCS1WithSHA384 SignatureScheme = 0x0501 + PKCS1WithSHA512 SignatureScheme = 0x0601 + + // RSASSA-PSS algorithms with public key OID rsaEncryption. + PSSWithSHA256 SignatureScheme = 0x0804 + PSSWithSHA384 SignatureScheme = 0x0805 + PSSWithSHA512 SignatureScheme = 0x0806 + + // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. + ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 + ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 + ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 + + // EdDSA algorithms. + Ed25519 SignatureScheme = 0x0807 + + // Legacy signature and hash algorithms for TLS 1.2. + PKCS1WithSHA1 SignatureScheme = 0x0201 + ECDSAWithSHA1 SignatureScheme = 0x0203 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide application logic in the GetCertificate and GetConfigForClient callbacks. +type ClientHelloInfo = tls.ClientHelloInfo + +type clientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see RFC 4366, Section 3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see RFC 4492, Section 5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see RFC 4492, Section 5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see RFC 7301, Section 3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn net.Conn + + // config is embedded by the GetCertificate or GetConfigForClient caller, + // for use with SupportsCertificate. + config *Config +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo = tls.CertificateRequestInfo + +type certificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme + + // Version is the TLS version that was negotiated for this connection. + Version uint16 +} + +// RenegotiationSupport enumerates the different levels of support for TLS +// renegotiation. TLS renegotiation is the act of performing subsequent +// handshakes on a connection after the first. This significantly complicates +// the state machine and has been the source of numerous, subtle security +// issues. Initiating a renegotiation is not supported, but support for +// accepting renegotiation requests may be enabled. +// +// Even when enabled, the server may not change its identity between handshakes +// (i.e. the leaf certificate must be the same). Additionally, concurrent +// handshake and application data flow is not permitted so renegotiation can +// only be used with protocols that synchronise with the renegotiation, such as +// HTTPS. +// +// Renegotiation is not defined in TLS 1.3. +type RenegotiationSupport = tls.RenegotiationSupport + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever = tls.RenegotiateNever + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config = tls.Config + +type config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to the + // other side of the connection. The first certificate compatible with the + // peer's requirements is selected automatically. + // + // Server configurations must set one of Certificates, GetCertificate or + // GetConfigForClient. Clients doing client-authentication may set either + // Certificates or GetClientCertificate. + // + // Note: if there are multiple Certificates, and they don't have the + // optional field Leaf set, certificate selection will incur a significant + // per-handshake performance cost. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // + // Deprecated: NameToCertificate only allows associating a single + // certificate with a given name. Leave this field nil to let the library + // select the first compatible chain from Certificates. + NameToCertificate map[string]*Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // best element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // If SessionTicketKey was explicitly set on the returned Config, or if + // SetSessionTicketKeys was called on the returned Config, those keys will + // be used. Otherwise, the original Config keys will be used (and possibly + // rotated if they are automatically managed). + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported application level protocols, in + // order of preference. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the server's + // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls + // accepts any certificate presented by the server and any host name in that + // certificate. In this mode, TLS is susceptible to machine-in-the-middle + // attacks unless custom verification is used. This should be used only for + // testing or in combination with VerifyConnection or VerifyPeerCertificate. + InsecureSkipVerify bool + + // CipherSuites is a list of supported cipher suites for TLS versions up to + // TLS 1.2. If CipherSuites is nil, a default list of secure cipher suites + // is used, with a preference order based on hardware performance. The + // default cipher suites might change over Go versions. Note that TLS 1.3 + // ciphersuites are not configurable. + CipherSuites []uint16 + + // PreferServerCipherSuites controls whether the server selects the + // client's most preferred ciphersuite, or the server's most preferred + // ciphersuite. If true then the server's preference, as expressed in + // the order of elements in CipherSuites, is used. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket and + // PSK (resumption) support. Note that on clients, session ticket support is + // also disabled if ClientSessionCache is nil. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session resumption. + // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled + // with random data before the first server handshake. + // + // Deprecated: if this field is left at zero, session ticket keys will be + // automatically rotated every day and dropped after seven days. For + // customizing the rotation schedule or synchronizing servers that are + // terminating connections for the same host, use SetSessionTicketKeys. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries for TLS + // session resumption. It is only used by clients. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum TLS version that is acceptable. + // If zero, TLS 1.0 is currently taken as the minimum. + MinVersion uint16 + + // MaxVersion contains the maximum TLS version that is acceptable. + // If zero, the maximum version supported by this package is used, + // which is currently TLS 1.3. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. The client will use the first preference as the type for + // its key share in TLS 1.3. This may change in the future. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // mutex protects sessionTicketKeys and autoSessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // the keys were set with SessionTicketKey or SetSessionTicketKeys. The + // first key is used for new tickets and any subsequent keys can be used to + // decrypt old tickets. The slice contents are not protected by the mutex + // and are immutable. + sessionTicketKeys []ticketKey + // autoSessionTicketKeys is like sessionTicketKeys but is owned by the + // auto-rotation logic. See Config.ticketKeys. + autoSessionTicketKeys []ticketKey +} + +// A RecordLayer handles encrypting and decrypting of TLS messages. +type RecordLayer interface { + SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + ReadHandshakeMessage() ([]byte, error) + WriteRecord([]byte) (int, error) + SendAlert(uint8) +} + +type ExtraConfig struct { + // GetExtensions, if not nil, is called before a message that allows + // sending of extensions is sent. + // Currently only implemented for the ClientHello message (for the client) + // and for the EncryptedExtensions message (for the server). + // Only valid for TLS 1.3. + GetExtensions func(handshakeMessageType uint8) []Extension + + // ReceivedExtensions, if not nil, is called when a message that allows the + // inclusion of extensions is received. + // It is called with an empty slice of extensions, if the message didn't + // contain any extensions. + // Currently only implemented for the ClientHello message (sent by the + // client) and for the EncryptedExtensions message (sent by the server). + // Only valid for TLS 1.3. + ReceivedExtensions func(handshakeMessageType uint8, exts []Extension) + + // AlternativeRecordLayer is used by QUIC + AlternativeRecordLayer RecordLayer + + // Enforce the selection of a supported application protocol. + // Only works for TLS 1.3. + // If enabled, client and server have to agree on an application protocol. + // Otherwise, connection establishment fails. + EnforceNextProtoSelection bool + + // If MaxEarlyData is greater than 0, the client will be allowed to send early + // data when resuming a session. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning on the client. + MaxEarlyData uint32 + + // The Accept0RTT callback is called when the client offers 0-RTT. + // The server then has to decide if it wants to accept or reject 0-RTT. + // It is only used for servers. + Accept0RTT func(appData []byte) bool + + // 0RTTRejected is called when the server rejectes 0-RTT. + // It is only used for clients. + Rejected0RTT func() + + // If set, the client will export the 0-RTT key when resuming a session that + // allows sending of early data. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning to the server. + Enable0RTT bool + + // Is called when the client saves a session ticket to the session ticket. + // This gives the application the opportunity to save some data along with the ticket, + // which can be restored when the session ticket is used. + GetAppDataForSessionState func() []byte + + // Is called when the client uses a session ticket. + // Restores the application data that was saved earlier on GetAppDataForSessionTicket. + SetAppDataFromSessionState func([]byte) +} + +// Clone clones. +func (c *ExtraConfig) Clone() *ExtraConfig { + return &ExtraConfig{ + GetExtensions: c.GetExtensions, + ReceivedExtensions: c.ReceivedExtensions, + AlternativeRecordLayer: c.AlternativeRecordLayer, + EnforceNextProtoSelection: c.EnforceNextProtoSelection, + MaxEarlyData: c.MaxEarlyData, + Enable0RTT: c.Enable0RTT, + Accept0RTT: c.Accept0RTT, + Rejected0RTT: c.Rejected0RTT, + GetAppDataForSessionState: c.GetAppDataForSessionState, + SetAppDataFromSessionState: c.SetAppDataFromSessionState, + } +} + +func (c *ExtraConfig) usesAlternativeRecordLayer() bool { + return c != nil && c.AlternativeRecordLayer != nil +} + +const ( + // ticketKeyNameLen is the number of bytes of identifier that is prepended to + // an encrypted session ticket in order to identify the key used to encrypt it. + ticketKeyNameLen = 16 + + // ticketKeyLifetime is how long a ticket key remains valid and can be used to + // resume a client connection. + ticketKeyLifetime = 7 * 24 * time.Hour // 7 days + + // ticketKeyRotation is how often the server should rotate the session ticket key + // that is used for new tickets. + ticketKeyRotation = 24 * time.Hour +) + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte + // created is the time at which this ticket key was created. See Config.ticketKeys. + created time.Time +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func (c *config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + key.created = c.time() + return key +} + +// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session +// ticket, and the lifetime we set for tickets we send. +const maxSessionTicketLifetime = 7 * 24 * time.Hour + +// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *config) Clone() *config { + if c == nil { + return nil + } + c.mutex.RLock() + defer c.mutex.RUnlock() + return &config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: c.sessionTicketKeys, + autoSessionTicketKeys: c.autoSessionTicketKeys, + } +} + +// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was +// randomized for backwards compatibility but is not in use. +var deprecatedSessionTicketKey = []byte("DEPRECATED") + +// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is +// randomized if empty, and that sessionTicketKeys is populated from it otherwise. +func (c *config) initLegacySessionTicketKeyRLocked() { + // Don't write if SessionTicketKey is already defined as our deprecated string, + // or if it is defined by the user but sessionTicketKeys is already set. + if c.SessionTicketKey != [32]byte{} && + (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { + return + } + + // We need to write some data, so get an exclusive lock and re-check any conditions. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + if c.SessionTicketKey == [32]byte{} { + if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) + } + // Write the deprecated prefix at the beginning so we know we created + // it. This key with the DEPRECATED prefix isn't used as an actual + // session ticket key, and is only randomized in case the application + // reuses it for some reason. + copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) + } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { + c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} + } + +} + +// ticketKeys returns the ticketKeys for this connection. +// If configForClient has explicitly set keys, those will +// be returned. Otherwise, the keys on c will be used and +// may be rotated if auto-managed. +// During rotation, any expired session ticket keys are deleted from +// c.sessionTicketKeys. If the session ticket key that is currently +// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) +// is not fresh, then a new session ticket key will be +// created and prepended to c.sessionTicketKeys. +func (c *config) ticketKeys(configForClient *config) []ticketKey { + // If the ConfigForClient callback returned a Config with explicitly set + // keys, use those, otherwise just use the original Config. + if configForClient != nil { + configForClient.mutex.RLock() + if configForClient.SessionTicketsDisabled { + return nil + } + configForClient.initLegacySessionTicketKeyRLocked() + if len(configForClient.sessionTicketKeys) != 0 { + ret := configForClient.sessionTicketKeys + configForClient.mutex.RUnlock() + return ret + } + configForClient.mutex.RUnlock() + } + + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.SessionTicketsDisabled { + return nil + } + c.initLegacySessionTicketKeyRLocked() + if len(c.sessionTicketKeys) != 0 { + return c.sessionTicketKeys + } + // Fast path for the common case where the key is fresh enough. + if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { + return c.autoSessionTicketKeys + } + + // autoSessionTicketKeys are managed by auto-rotation. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + // Re-check the condition in case it changed since obtaining the new lock. + if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { + var newKey [32]byte + if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { + panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) + } + valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) + valid = append(valid, c.ticketKeyFromBytes(newKey)) + for _, k := range c.autoSessionTicketKeys { + // While rotating the current key, also remove any expired ones. + if c.time().Sub(k.created) < ticketKeyLifetime { + valid = append(valid, k) + } + } + c.autoSessionTicketKeys = valid + } + return c.autoSessionTicketKeys +} + +// SetSessionTicketKeys updates the session ticket keys for a server. +// +// The first key will be used when creating new tickets, while all keys can be +// used for decrypting tickets. It is safe to call this function while the +// server is running in order to rotate the session ticket keys. The function +// will panic if keys is empty. +// +// Calling this function will turn off automatic session ticket key rotation. +// +// If multiple servers are terminating connections for the same host they should +// all have the same session ticket keys. If the session ticket keys leaks, +// previously recorded and future TLS connections using those keys might be +// compromised. +func (c *config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = c.ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *config) cipherSuites() []uint16 { + s := c.CipherSuites + if s == nil { + s = defaultCipherSuites() + } + return s +} + +var supportedVersions = []uint16{ + VersionTLS13, + VersionTLS12, + VersionTLS11, + VersionTLS10, +} + +func (c *config) supportedVersions() []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if c != nil && c.MinVersion != 0 && v < c.MinVersion { + continue + } + if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +func (c *config) maxSupportedVersion() uint16 { + supportedVersions := c.supportedVersions() + if len(supportedVersions) == 0 { + return 0 + } + return supportedVersions[0] +} + +// supportedVersionsFromMax returns a list of supported versions derived from a +// legacy maximum version value. Note that only versions supported by this +// library are returned. Any newer peer will use supportedVersions anyway. +func supportedVersionsFromMax(maxVersion uint16) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if v > maxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +func (c *config) supportsCurve(curve CurveID) bool { + for _, cc := range c.curvePreferences() { + if cc == curve { + return true + } + } + return false +} + +// mutualVersion returns the protocol version to use given the advertised +// versions of the peer. Priority is given to the peer preference order. +func (c *config) mutualVersion(peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions() + for _, peerVersion := range peerVersions { + for _, v := range supportedVersions { + if v == peerVersion { + return v, true + } + } + } + return 0, false +} + +var errNoCertificates = errors.New("tls: no certificates configured") + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errNoCertificates + } + + if len(c.Certificates) == 1 { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + if c.NameToCertificate != nil { + name := strings.ToLower(clientHello.ServerName) + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + if len(name) > 0 { + labels := strings.Split(name, ".") + labels[0] = "*" + wildcardName := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[wildcardName]; ok { + return cert, nil + } + } + } + + for _, cert := range c.Certificates { + if err := clientHello.SupportsCertificate(&cert); err == nil { + return &cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the client that sent the ClientHello. Otherwise, it returns an error +// describing the reason for the incompatibility. +// +// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate +// callback, this method will take into account the associated Config. Note that +// if GetConfigForClient returns a different Config, the change can't be +// accounted for by this method. +// +// This function will call x509.ParseCertificate unless c.Leaf is set, which can +// incur a significant performance cost. +func (chi *clientHelloInfo) SupportsCertificate(c *Certificate) error { + // Note we don't currently support certificate_authorities nor + // signature_algorithms_cert, and don't check the algorithms of the + // signatures on the chain (which anyway are a SHOULD, see RFC 8446, + // Section 4.4.2.2). + + config := chi.config + if config == nil { + config = &Config{} + } + conf := fromConfig(config) + vers, ok := conf.mutualVersion(chi.SupportedVersions) + if !ok { + return errors.New("no mutually supported protocol versions") + } + + // If the client specified the name they are trying to connect to, the + // certificate needs to be valid for it. + if chi.ServerName != "" { + x509Cert, err := leafCertificate(c) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { + return fmt.Errorf("certificate is not valid for requested server name: %w", err) + } + } + + // supportsRSAFallback returns nil if the certificate and connection support + // the static RSA key exchange, and unsupported otherwise. The logic for + // supporting static RSA is completely disjoint from the logic for + // supporting signed key exchanges, so we just check it as a fallback. + supportsRSAFallback := func(unsupported error) error { + // TLS 1.3 dropped support for the static RSA key exchange. + if vers == VersionTLS13 { + return unsupported + } + // The static RSA key exchange works by decrypting a challenge with the + // RSA private key, not by signing, so check the PrivateKey implements + // crypto.Decrypter, like *rsa.PrivateKey does. + if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { + if _, ok := priv.Public().(*rsa.PublicKey); !ok { + return unsupported + } + } else { + return unsupported + } + // Finally, there needs to be a mutual cipher suite that uses the static + // RSA key exchange instead of ECDHE. + rsaCipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + return false + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if rsaCipherSuite == nil { + return unsupported + } + return nil + } + + // If the client sent the signature_algorithms extension, ensure it supports + // schemes we can use with this certificate and TLS version. + if len(chi.SignatureSchemes) > 0 { + if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { + return supportsRSAFallback(err) + } + } + + // In TLS 1.3 we are done because supported_groups is only relevant to the + // ECDHE computation, point format negotiation is removed, cipher suites are + // only relevant to the AEAD choice, and static RSA does not exist. + if vers == VersionTLS13 { + return nil + } + + // The only signed key exchange we support is ECDHE. + if !supportsECDHE(conf, chi.SupportedCurves, chi.SupportedPoints) { + return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) + } + + var ecdsaCipherSuite bool + if priv, ok := c.PrivateKey.(crypto.Signer); ok { + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + var curve CurveID + switch pub.Curve { + case elliptic.P256(): + curve = CurveP256 + case elliptic.P384(): + curve = CurveP384 + case elliptic.P521(): + curve = CurveP521 + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + var curveOk bool + for _, c := range chi.SupportedCurves { + if c == curve && conf.supportsCurve(c) { + curveOk = true + break + } + } + if !curveOk { + return errors.New("client doesn't support certificate curve") + } + ecdsaCipherSuite = true + case ed25519.PublicKey: + if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { + return errors.New("connection doesn't support Ed25519") + } + ecdsaCipherSuite = true + case *rsa.PublicKey: + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + } else { + return supportsRSAFallback(unsupportedCertificateError(c)) + } + + // Make sure that there is a mutually supported cipher suite that works with + // this certificate. Cipher suite selection will then apply the logic in + // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. + cipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE == 0 { + return false + } + if c.flags&suiteECSign != 0 { + if !ecdsaCipherSuite { + return false + } + } else { + if ecdsaCipherSuite { + return false + } + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if cipherSuite == nil { + return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) + } + + return nil +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +// +// Deprecated: NameToCertificate only allows associating a single certificate +// with a given name. Leave that field nil to let the library select the first +// compatible chain from Certificates. +func (c *config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := leafCertificate(cert) + if err != nil { + continue + } + // If SANs are *not* present, some clients will consider the certificate + // valid for the name in the Common Name. + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +const ( + keyLogLabelTLS12 = "CLIENT_RANDOM" + keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET" + keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" + keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" +) + +func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate = tls.Certificate + +// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing +// the corresponding c.Certificate[0]. +func leafCertificate(c *Certificate) (*x509.Certificate, error) { + if c.Leaf != nil { + return c.Leaf, nil + } + return x509.ParseCertificate(c.Certificate[0]) +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry +// corresponding to sessionKey is removed from the cache instead. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + if cs == nil { + c.q.Remove(elem) + delete(c.m, sessionKey) + } else { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + } + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +var ( + once sync.Once + varDefaultCipherSuites []uint16 + varDefaultCipherSuitesTLS13 []uint16 +) + +func defaultCipherSuites() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuites +} + +func defaultCipherSuitesTLS13() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuitesTLS13 +} + +func initDefaultCipherSuites() { + var topCipherSuites []uint16 + + if hasAESGCMHardwareSupport { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_256_GCM_SHA384, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } + } + + varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) + varDefaultCipherSuites = append(varDefaultCipherSuites, topCipherSuites...) + +NextCipherSuite: + for _, suite := range cipherSuites { + if suite.flags&suiteDefaultOff != 0 { + continue + } + for _, existing := range varDefaultCipherSuites { + if existing == suite.id { + continue NextCipherSuite + } + } + varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) + } +} + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { + for _, s := range supportedSignatureAlgorithms { + if s == sigAlg { + return true + } + } + return false +} + +var aesgcmCiphers = map[uint16]bool{ + // 1.2 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, + // 1.3 + TLS_AES_128_GCM_SHA256: true, + TLS_AES_256_GCM_SHA384: true, +} + +var nonAESGCMAEADCiphers = map[uint16]bool{ + // 1.2 + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, + // 1.3 + TLS_CHACHA20_POLY1305_SHA256: true, +} + +// aesgcmPreferred returns whether the first valid cipher in the preference list +// is an AES-GCM cipher, implying the peer has hardware support for it. +func aesgcmPreferred(ciphers []uint16) bool { + for _, cID := range ciphers { + c := cipherSuiteByID(cID) + if c == nil { + c13 := cipherSuiteTLS13ByID(cID) + if c13 == nil { + continue + } + return aesgcmCiphers[cID] + } + return aesgcmCiphers[cID] + } + return false +} + +// deprioritizeAES reorders cipher preference lists by rearranging +// adjacent AEAD ciphers such that AES-GCM based ciphers are moved +// after other AEAD ciphers. It returns a fresh slice. +func deprioritizeAES(ciphers []uint16) []uint16 { + reordered := make([]uint16, len(ciphers)) + copy(reordered, ciphers) + sort.SliceStable(reordered, func(i, j int) bool { + return nonAESGCMAEADCiphers[reordered[i]] && aesgcmCiphers[reordered[j]] + }) + return reordered +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/common_js.go b/vendor/github.com/marten-seemann/qtls-go1-16/common_js.go new file mode 100644 index 00000000..97e6ecef --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/common_js.go @@ -0,0 +1,12 @@ +// +build js + +package qtls + +var ( + hasGCMAsmAMD64 = false + hasGCMAsmARM64 = false + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = false + + hasAESGCMHardwareSupport = false +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go b/vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go new file mode 100644 index 00000000..5e56e0fb --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go @@ -0,0 +1,20 @@ +// +build !js + +package qtls + +import ( + "runtime" + + "golang.org/x/sys/cpu" +) + +var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || + runtime.GOARCH == "arm64" && hasGCMAsmARM64 || + runtime.GOARCH == "s390x" && hasGCMAsmS390X +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/conn.go b/vendor/github.com/marten-seemann/qtls-go1-16/conn.go new file mode 100644 index 00000000..fa5eb3f1 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/conn.go @@ -0,0 +1,1536 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package qtls + +import ( + "bytes" + "crypto/cipher" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + isClient bool + handshakeFn func() error // (*Conn).clientHandshake or serverHandshake + + // handshakeStatus is 1 if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // This field is only to be accessed with sync/atomic. + handshakeStatus uint32 + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *config // configuration passed to constructor + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + extraConfig *ExtraConfig + + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + // ekm is a closure for exporting keying material. + ekm func(label string, context []byte, length int) ([]byte, error) + // For the client: + // resumptionSecret is the resumption_master_secret for handling + // NewSessionTicket messages. nil if config.SessionTicketsDisabled. + // For the server: + // resumptionSecret is the resumption_master_secret for generating + // NewSessionTicket messages. Only used when the alternative record + // layer is set. nil if config.SessionTicketsDisabled. + resumptionSecret []byte + + // ticketKeys is the set of active session ticket keys for this + // connection. The first one is used to encrypt new tickets and + // all are tried to decrypt tickets. + ticketKeys []ticketKey + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string + + // input/output + in, out halfConn + rawInput bytes.Buffer // raw input, starting with a record header + input bytes.Reader // application data waiting to be read, from rawInput.Next + hand bytes.Buffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf []byte // a buffer of records waiting to be sent + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 + + // retryCount counts the number of consecutive non-advancing records + // received by Conn.readRecord. That is, records that neither advance the + // handshake, nor deliver application data. Protected by in.Mutex. + retryCount int + + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + activeCall int32 + + used0RTT bool + + tmp [16]byte +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape + + nextCipher interface{} // next encryption state + nextMac hash.Hash // next MAC algorithm + + trafficSecret []byte // current TLS 1.3 traffic secret + + setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) +} + +type permanentError struct { + err net.Error +} + +func (e *permanentError) Error() string { return e.err.Error() } +func (e *permanentError) Unwrap() error { return e.err } +func (e *permanentError) Timeout() bool { return e.err.Timeout() } +func (e *permanentError) Temporary() bool { return false } + +func (hc *halfConn) setErrorLocked(err error) error { + if e, ok := err.(net.Error); ok { + hc.err = &permanentError{err: e} + } else { + hc.err = err + } + return hc.err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() error { + if hc.nextCipher == nil || hc.version == VersionTLS13 { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + for i := range hc.seq { + hc.seq[i] = 0 + } + return nil +} + +func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) { + if hc.setKeyCallback != nil { + s := &CipherSuiteTLS13{ + ID: suite.id, + KeyLen: suite.keyLen, + Hash: suite.hash, + AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) }, + } + hc.setKeyCallback(encLevel, s, trafficSecret) + } +} + +func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { + hc.trafficSecret = secret + key, iv := suite.trafficKey(secret) + hc.cipher = suite.aead(key, iv) + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// explicitNonceLen returns the number of bytes of explicit nonce or IV included +// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 +// and in certain AEAD modes in TLS 1.2. +func (hc *halfConn) explicitNonceLen() int { + if hc.cipher == nil { + return 0 + } + + switch c := hc.cipher.(type) { + case cipher.Stream: + return 0 + case aead: + return c.explicitNonceLen() + case cbcMode: + // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. + if hc.version >= VersionTLS11 { + return c.BlockSize() + } + return 0 + default: + panic("unknown cipher type") + } +} + +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +func extractPadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + // Zero the padding length on error. This ensures any unchecked bytes + // are included in the MAC. Otherwise, an attacker that could + // distinguish MAC failures from padding failures could mount an attack + // similar to POODLE in SSL 3.0: given a good ciphertext that uses a + // full block's worth of padding, replace the final block with another + // block. If the MAC check passed but the padding check failed, the + // last byte of that block decrypted to the block size. + // + // See also macAndPaddingGood logic below. + paddingLen &= good + + toRemove = int(paddingLen) + 1 + return +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt authenticates and decrypts the record if protection is active at +// this stage. The returned plaintext might overlap with the input. +func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { + var plaintext []byte + typ := recordType(record[0]) + payload := record[recordHeaderLen:] + + // In TLS 1.3, change_cipher_spec messages are to be ignored without being + // decrypted. See RFC 8446, Appendix D.4. + if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { + return payload, typ, nil + } + + paddingGood := byte(255) + paddingLen := 0 + + explicitNonceLen := hc.explicitNonceLen() + + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + if len(payload) < explicitNonceLen { + return nil, 0, alertBadRecordMAC + } + nonce := payload[:explicitNonceLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } + payload = payload[explicitNonceLen:] + + var additionalData []byte + if hc.version == VersionTLS13 { + additionalData = record[:recordHeaderLen] + } else { + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) + n := len(payload) - c.Overhead() + additionalData = append(additionalData, byte(n>>8), byte(n)) + } + + var err error + plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) + if err != nil { + return nil, 0, alertBadRecordMAC + } + case cbcMode: + blockSize := c.BlockSize() + minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) + if len(payload)%blockSize != 0 || len(payload) < minPayload { + return nil, 0, alertBadRecordMAC + } + + if explicitNonceLen > 0 { + c.SetIV(payload[:explicitNonceLen]) + payload = payload[explicitNonceLen:] + } + c.CryptBlocks(payload, payload) + + // In a limited attempt to protect against CBC padding oracles like + // Lucky13, the data past paddingLen (which is secret) is passed to + // the MAC function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC roughly constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write, modulo cache effects. + paddingLen, paddingGood = extractPadding(payload) + default: + panic("unknown cipher type") + } + + if hc.version == VersionTLS13 { + if typ != recordTypeApplicationData { + return nil, 0, alertUnexpectedMessage + } + if len(plaintext) > maxPlaintext+1 { + return nil, 0, alertRecordOverflow + } + // Remove padding and find the ContentType scanning from the end. + for i := len(plaintext) - 1; i >= 0; i-- { + if plaintext[i] != 0 { + typ = recordType(plaintext[i]) + plaintext = plaintext[:i] + break + } + if i == 0 { + return nil, 0, alertUnexpectedMessage + } + } + } + } else { + plaintext = payload + } + + if hc.mac != nil { + macSize := hc.mac.Size() + if len(payload) < macSize { + return nil, 0, alertBadRecordMAC + } + + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } + record[3] = byte(n >> 8) + record[4] = byte(n) + remoteMAC := payload[n : n+macSize] + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + + // This is equivalent to checking the MACs and paddingGood + // separately, but in constant-time to prevent distinguishing + // padding failures from MAC failures. Depending on what value + // of paddingLen was returned on bad padding, distinguishing + // bad MAC from bad padding can lead to an attack. + // + // See also the logic at the end of extractPadding. + macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) + if macAndPaddingGood != 1 { + return nil, 0, alertBadRecordMAC + } + + plaintext = payload[:n] + } + + hc.incSeq() + return plaintext, typ, nil +} + +func (c *Conn) setAlternativeRecordLayer() { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey + c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey + } +} + +// sliceForAppend extends the input slice by n bytes. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and +// appends it to record, which must already contain the record header. +func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { + if hc.cipher == nil { + return append(record, payload...), nil + } + + var explicitNonce []byte + if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { + record, explicitNonce = sliceForAppend(record, explicitNonceLen) + if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { + // The AES-GCM construction in TLS has an explicit nonce so that the + // nonce can be random. However, the nonce is only 8 bytes which is + // too small for a secure, random nonce. Therefore we use the + // sequence number as the nonce. The 3DES-CBC construction also has + // an 8 bytes nonce but its nonces must be unpredictable (see RFC + // 5246, Appendix F.3), forcing us to use randomness. That's not + // 3DES' biggest problem anyway because the birthday bound on block + // collision is reached first due to its similarly small block size + // (see the Sweet32 attack). + copy(explicitNonce, hc.seq[:]) + } else { + if _, err := io.ReadFull(rand, explicitNonce); err != nil { + return nil, err + } + } + } + + var dst []byte + switch c := hc.cipher.(type) { + case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + record, dst = sliceForAppend(record, len(payload)+len(mac)) + c.XORKeyStream(dst[:len(payload)], payload) + c.XORKeyStream(dst[len(payload):], mac) + case aead: + nonce := explicitNonce + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + if hc.version == VersionTLS13 { + record = append(record, payload...) + + // Encrypt the actual ContentType and replace the plaintext one. + record = append(record, record[0]) + record[0] = byte(recordTypeApplicationData) + + n := len(payload) + 1 + c.Overhead() + record[3] = byte(n >> 8) + record[4] = byte(n) + + record = c.Seal(record[:recordHeaderLen], + nonce, record[recordHeaderLen:], record[:recordHeaderLen]) + } else { + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) + } + case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + blockSize := c.BlockSize() + plaintextLen := len(payload) + len(mac) + paddingLen := blockSize - plaintextLen%blockSize + record, dst = sliceForAppend(record, plaintextLen+paddingLen) + copy(dst, payload) + copy(dst[len(payload):], mac) + for i := plaintextLen; i < len(dst); i++ { + dst[i] = byte(paddingLen - 1) + } + if len(explicitNonce) > 0 { + c.SetIV(explicitNonce) + } + c.CryptBlocks(dst, dst) + default: + panic("unknown cipher type") + } + + // Update length to include nonce, MAC and any block padding needed. + n := len(record) - recordHeaderLen + record[3] = byte(n >> 8) + record[4] = byte(n) + hc.incSeq() + + return record, nil +} + +// RecordHeaderError is returned when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte + // Conn provides the underlying net.Conn in the case that a client + // sent an initial handshake that didn't look like TLS. + // It is nil if there's already been a handshake or a TLS alert has + // been written to the connection. + Conn net.Conn +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { + err.Msg = msg + err.Conn = conn + copy(err.RecordHeader[:], c.rawInput.Bytes()) + return err +} + +func (c *Conn) readRecord() error { + return c.readRecordOrCCS(false) +} + +func (c *Conn) readChangeCipherSpec() error { + return c.readRecordOrCCS(true) +} + +// readRecordOrCCS reads one or more TLS records from the connection and +// updates the record layer state. Some invariants: +// * c.in must be locked +// * c.input must be empty +// During the handshake one and only one of the following will happen: +// - c.hand grows +// - c.in.changeCipherSpec is called +// - an error is returned +// After the handshake one and only one of the following will happen: +// - c.hand grows +// - c.input is set +// - an error is returned +func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { + if c.in.err != nil { + return c.in.err + } + handshakeComplete := c.handshakeComplete() + + // This function modifies c.rawInput, which owns the c.input memory. + if c.input.Len() != 0 { + return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) + } + c.input.Reset(nil) + + // Read header, payload. + if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify + // is an error, but popular web sites seem to do this, so we accept it + // if and only if at the record boundary. + if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 { + err = io.EOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + hdr := c.rawInput.Bytes()[:recordHeaderLen] + typ := recordType(hdr[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + if !handshakeComplete && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + } + + vers := uint16(hdr[1])<<8 | uint16(hdr[2]) + n := int(hdr[3])<<8 | int(hdr[4]) + if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { + c.sendAlert(alertProtocolVersion) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if !c.haveVers { + // First message, be extra suspicious: this might not be a TLS + // client. Bail out before reading a full 'body', if possible. + // The current max version is 3.3 so if the version is >= 16.0, + // it's probably not real. + if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { + return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) + } + } + if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + + // Process message. + record := c.rawInput.Next(recordHeaderLen + n) + data, typ, err := c.in.decrypt(record) + if err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + if len(data) > maxPlaintext { + return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) + } + + // Application Data messages are always protected. + if c.in.cipher == nil && typ == recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + c.retryCount = 0 + } + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + switch typ { + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if alert(data[1]) == alertCloseNotify { + return c.in.setErrorLocked(io.EOF) + } + if c.vers == VersionTLS13 { + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + } + switch data[0] { + case alertLevelWarning: + // Drop the record on the floor and retry. + return c.retryReadRecord(expectChangeCipherSpec) + case alertLevelError: + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if len(data) != 1 || data[0] != 1 { + return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) + } + // Handshake messages are not allowed to fragment across the CCS. + if c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // In TLS 1.3, change_cipher_spec records are ignored until the + // Finished. See RFC 8446, Appendix D.4. Note that according to Section + // 5, a server can send a ChangeCipherSpec before its ServerHello, when + // c.vers is still unset. That's not useful though and suspicious if the + // server then selects a lower protocol version, so don't allow that. + if c.vers == VersionTLS13 { + return c.retryReadRecord(expectChangeCipherSpec) + } + if !expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if err := c.in.changeCipherSpec(); err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if !handshakeComplete || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. Ignore a limited number of empty records. + if len(data) == 0 { + return c.retryReadRecord(expectChangeCipherSpec) + } + // Note that data is owned by c.rawInput, following the Next call above, + // to avoid copying the plaintext. This is safe because c.rawInput is + // not read from or written to until c.input is drained. + c.input.Reset(data) + + case recordTypeHandshake: + if len(data) == 0 || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.hand.Write(data) + } + + return nil +} + +// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + } + return c.readRecordOrCCS(expectChangeCipherSpec) +} + +// atLeastReader reads from R, stopping with EOF once at least N bytes have been +// read. It is different from an io.LimitedReader in that it doesn't cut short +// the last Read call, and in that it considers an early EOF an error. +type atLeastReader struct { + R io.Reader + N int64 +} + +func (r *atLeastReader) Read(p []byte) (int, error) { + if r.N <= 0 { + return 0, io.EOF + } + n, err := r.R.Read(p) + r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 + if r.N > 0 && err == io.EOF { + return n, io.ErrUnexpectedEOF + } + if r.N <= 0 && err == nil { + return n, io.EOF + } + return n, err +} + +// readFromUntil reads from r into c.rawInput until c.rawInput contains +// at least n bytes or else returns an error. +func (c *Conn) readFromUntil(r io.Reader, n int) error { + if c.rawInput.Len() >= n { + return nil + } + needs := n - c.rawInput.Len() + // There might be extra input waiting on the wire. Make a best effort + // attempt to fetch it so that it can be used in (*Conn).Read to + // "predict" closeNotify alerts. + c.rawInput.Grow(needs + bytes.MinRead) + _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) + return err +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlert(err alert) error { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err)) + return &net.OpError{Op: "local error", Err: err} + } + + c.out.Lock() + defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= c.out.mac.Size() + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= c.out.mac.Size() + default: + panic("unknown cipher type") + } + } + if c.vers == VersionTLS13 { + payloadBytes-- // encrypted ContentType + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +func (c *Conn) write(data []byte) (int, error) { + if c.buffering { + c.sendBuf = append(c.sendBuf, data...) + return len(data), nil + } + + n, err := c.conn.Write(data) + c.bytesSent += int64(n) + return n, err +} + +func (c *Conn) flush() (int, error) { + if len(c.sendBuf) == 0 { + return 0, nil + } + + n, err := c.conn.Write(c.sendBuf) + c.bytesSent += int64(n) + c.sendBuf = nil + c.buffering = false + return n, err +} + +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() interface{} { + return new([]byte) + }, +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + + var n int + for len(data) > 0 { + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } else if vers == VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = VersionTLS12 + } + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) + + var err error + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) + if err != nil { + return n, err + } + if _, err := c.write(outBuf); err != nil { + return n, err + } + n += m + data = data[m:] + } + + if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) + } + } + + return n, nil +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + if typ == recordTypeChangeCipherSpec { + return len(data), nil + } + return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) + } + + c.out.Lock() + defer c.out.Unlock() + + return c.writeRecordLocked(typ, data) +} + +// readHandshake reads the next handshake message from +// the record layer. +func (c *Conn) readHandshake() (interface{}, error) { + var data []byte + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + var err error + data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage() + if err != nil { + return nil, err + } + } else { + for c.hand.Len() < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + + data = c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) + } + for c.hand.Len() < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + data = c.hand.Next(4 + n) + } + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeNewSessionTicket: + if c.vers == VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + m = new(newSessionTicketMsg) + } + case typeCertificate: + if c.vers == VersionTLS13 { + m = new(certificateMsgTLS13) + } else { + m = new(certificateMsg) + } + case typeCertificateRequest: + if c.vers == VersionTLS13 { + m = new(certificateRequestMsgTLS13) + } else { + m = &certificateRequestMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + case typeFinished: + m = new(finishedMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return m, nil +} + +var ( + errShutdown = errors.New("tls: protocol is shutdown") +) + +// Write writes data to the connection. +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Write(b []byte) (int, error) { + // interlock with Close below + for { + x := atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return 0, net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + break + } + } + defer atomic.AddInt32(&c.activeCall, -2) + + if err := c.Handshake(); err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if err := c.out.err; err != nil { + return 0, err + } + + if !c.handshakeComplete() { + return 0, alertInternalError + } + + if c.closeNotifySent { + return 0, errShutdown + } + + // TLS 1.0 is susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // https://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // https://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && c.vers == VersionTLS10 { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) +} + +// handleRenegotiation processes a HelloRequest handshake message. +func (c *Conn) handleRenegotiation() error { + if c.vers == VersionTLS13 { + return errors.New("tls: internal error: unexpected renegotiation") + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + helloReq, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(helloReq, msg) + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + atomic.StoreUint32(&c.handshakeStatus, 0) + if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +func (c *Conn) HandlePostHandshakeMessage() error { + return c.handlePostHandshakeMessage() +} + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *Conn) handlePostHandshakeMessage() error { + if c.vers != VersionTLS13 { + return c.handleRenegotiation() + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + } + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + return c.handleNewSessionTicket(msg) + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + default: + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) + } +} + +func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil { + return c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + + newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) + c.in.setTrafficSecret(cipherSuite, newSecret) + + if keyUpdate.updateRequested { + c.out.Lock() + defer c.out.Unlock() + + msg := &keyUpdateMsg{} + _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + if err != nil { + // Surface the error at the next write. + c.out.setErrorLocked(err) + return nil + } + + newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) + c.out.setTrafficSecret(cipherSuite, newSecret) + } + + return nil +} + +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Read(b []byte) (int, error) { + if err := c.Handshake(); err != nil { + return 0, err + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + c.in.Lock() + defer c.in.Unlock() + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + return 0, err + } + for c.hand.Len() > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + return 0, err + } + } + } + + n, _ := c.input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && + recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +// Close closes the connection. +func (c *Conn) Close() error { + // Interlock with Conn.Write above. + var x int32 + for { + x = atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + + var alertErr error + if c.handshakeComplete() { + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } + } + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + if !c.handshakeComplete() { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + c.out.Lock() + defer c.out.Unlock() + + if !c.closeNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// the Dialer's DialContext method. +func (c *Conn) Handshake() error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete() { + return nil + } + + c.in.Lock() + defer c.in.Unlock() + + c.handshakeErr = c.handshakeFn() + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the handshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + if c.handshakeErr == nil && !c.handshakeComplete() { + c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} + +// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection. +func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return ConnectionStateWith0RTT{ + ConnectionState: c.connectionStateLocked(), + Used0RTT: c.used0RTT, + } +} + +func (c *Conn) connectionStateLocked() ConnectionState { + var state connectionState + state.HandshakeComplete = c.handshakeComplete() + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = true + state.ServerName = c.serverName + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } + return toConnectionState(state) +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.handshakeComplete() { + return errors.New("tls: handshake has not yet been performed") + } + if len(c.verifiedChains) == 0 { + return errors.New("tls: handshake did not verify certificate chain") + } + return c.peerCertificates[0].VerifyHostname(host) +} + +func (c *Conn) handshakeComplete() bool { + return atomic.LoadUint32(&c.handshakeStatus) == 1 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go new file mode 100644 index 00000000..a447061a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go @@ -0,0 +1,1105 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "strings" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +const clientSessionStateVersion = 1 + +type clientHandshakeState struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *clientSessionState +} + +func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { + config := c.config + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + nextProtosLength := 0 + for _, proto := range config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return nil, nil, errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + if nextProtosLength > 0xffff { + return nil, nil, errors.New("tls: NextProtos values too large") + } + + var supportedVersions []uint16 + var clientHelloVersion uint16 + if c.extraConfig.usesAlternativeRecordLayer() { + if config.maxSupportedVersion() < VersionTLS13 { + return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + // Only offer TLS 1.3 when QUIC is used. + supportedVersions = []uint16{VersionTLS13} + clientHelloVersion = VersionTLS13 + } else { + supportedVersions = config.supportedVersions() + if len(supportedVersions) == 0 { + return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") + } + clientHelloVersion = config.maxSupportedVersion() + } + + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if clientHelloVersion > VersionTLS12 { + clientHelloVersion = VersionTLS12 + } + + hello := &clientHelloMsg{ + vers: clientHelloVersion, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + ocspStapling: true, + scts: true, + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + secureRenegotiationSupported: true, + alpnProtocols: config.NextProtos, + supportedVersions: supportedVersions, + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + + possibleCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) + + // add non-TLS 1.3 cipher suites + if c.config.MinVersion <= VersionTLS12 { + for _, suiteId := range possibleCipherSuites { + for _, suite := range cipherSuites { + if suite.id != suiteId { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + break + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + break + } + } + } + + _, err := io.ReadFull(config.rand(), hello.random) + if err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + hello.sessionId = make([]byte, 32) + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + } + + if hello.vers >= VersionTLS12 { + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + var params ecdheParameters + if hello.supportedVersions[0] == VersionTLS13 { + var hasTLS13CipherSuite bool + // add TLS 1.3 cipher suites + for _, suiteID := range possibleCipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + hasTLS13CipherSuite = true + hello.cipherSuites = append(hello.cipherSuites, suiteID) + } + } + } + if !hasTLS13CipherSuite { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...) + } + + curveID := config.curvePreferences()[0] + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err = generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil { + hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello) + } + + return hello, params, nil +} + +func (c *Conn) clientHandshake() (err error) { + if c.config == nil { + c.config = fromConfig(defaultConfig()) + } + c.setAlternativeRecordLayer() + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, ecdheParams, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName + + cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + if cacheKey != "" && session != nil { + var deletedTicket bool + if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { + // don't reuse a session ticket that enabled 0-RTT + c.config.ClientSessionCache.Put(cacheKey, nil) + deletedTicket = true + + if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { + h := suite.hash.New() + h.Write(hello.marshal()) + clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) + c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + } + if !deletedTicket { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + }() + } + } + + if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + + if c.vers == VersionTLS13 { + hs := &clientHandshakeStateTLS13{ + c: c, + serverHello: serverHello, + hello: hello, + ecdheParams: ecdheParams, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + } + + // In TLS 1.3, session tickets are delivered after the handshake. + return hs.handshake() + } + + hs := &clientHandshakeState{ + c: c, + serverHello: serverHello, + hello: hello, + session: session, + } + + if err := hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one. + if cacheKey != "" && hs.session != nil && session != hs.session { + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) + } + + return nil +} + +// extract the app data saved in the session.nonce, +// and set the session.nonce to the actual nonce value +func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) { + s := cryptobyte.String(session.nonce) + var version uint16 + if !s.ReadUint16(&version) { + return 0, nil, false + } + if version != clientSessionStateVersion { + return 0, nil, false + } + var maxEarlyData uint32 + if !s.ReadUint32(&maxEarlyData) { + return 0, nil, false + } + var appData []byte + if !readUint16LengthPrefixed(&s, &appData) { + return 0, nil, false + } + var nonce []byte + if !readUint16LengthPrefixed(&s, &nonce) { + return 0, nil, false + } + session.nonce = nonce + return maxEarlyData, appData, true +} + +func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, + session *clientSessionState, earlySecret, binderKey []byte) { + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return "", nil, nil, nil + } + + hello.ticketSupported = true + + if hello.supportedVersions[0] == VersionTLS13 { + // Require DHE on resumption as it guarantees forward secrecy against + // compromise of the session ticket key. See RFC 8446, Section 4.2.9. + hello.pskModes = []uint8{pskModeDHE} + } + + // Session resumption is not allowed if renegotiating because + // renegotiation is primarily used to allow a client to send a client + // certificate, which would be skipped if session resumption occurred. + if c.handshakes != 0 { + return "", nil, nil, nil + } + + // Try to resume a previously negotiated TLS session, if available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + sess, ok := c.config.ClientSessionCache.Get(cacheKey) + if !ok || sess == nil { + return cacheKey, nil, nil, nil + } + session = fromClientSessionState(sess) + + var appData []byte + var maxEarlyData uint32 + if session.vers == VersionTLS13 { + var ok bool + maxEarlyData, appData, ok = c.decodeSessionState(session) + if !ok { // delete it, if parsing failed + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + } + + // Check that version used for the previous session is still valid. + versOk := false + for _, v := range hello.supportedVersions { + if v == session.vers { + versOk = true + break + } + } + if !versOk { + return cacheKey, nil, nil, nil + } + + // Check that the cached server certificate is not expired, and that it's + // valid for the ServerName. This should be ensured by the cache key, but + // protect the application from a faulty ClientSessionCache implementation. + if !c.config.InsecureSkipVerify { + if len(session.verifiedChains) == 0 { + // The original connection had InsecureSkipVerify, while this doesn't. + return cacheKey, nil, nil, nil + } + serverCert := session.serverCertificates[0] + if c.config.time().After(serverCert.NotAfter) { + // Expired certificate, delete the entry. + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { + return cacheKey, nil, nil, nil + } + } + + if session.vers != VersionTLS13 { + // In TLS 1.2 the cipher suite must match the resumed session. Ensure we + // are still offering it. + if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { + return cacheKey, nil, nil, nil + } + + hello.sessionTicket = session.sessionTicket + return + } + + // Check that the session ticket is not expired. + if c.config.time().After(session.useBy) { + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + + // In TLS 1.3 the KDF hash must match the resumed session. Ensure we + // offer at least one cipher suite with that hash. + cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) + if cipherSuite == nil { + return cacheKey, nil, nil, nil + } + cipherSuiteOk := false + for _, offeredID := range hello.cipherSuites { + offeredSuite := cipherSuiteTLS13ByID(offeredID) + if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return cacheKey, nil, nil, nil + } + + // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. + ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) + identity := pskIdentity{ + label: session.sessionTicket, + obfuscatedTicketAge: ticketAge + session.ageAdd, + } + hello.pskIdentities = []pskIdentity{identity} + hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} + + // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. + psk := cipherSuite.expandLabel(session.masterSecret, "resumption", + session.nonce, cipherSuite.hash.Size()) + earlySecret = cipherSuite.extract(psk, nil) + binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + if c.extraConfig != nil { + hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 + } + transcript := cipherSuite.hash.New() + transcript.Write(hello.marshalWithoutBinders()) + pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} + hello.updateBinders(pskBinders) + + if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { + c.extraConfig.SetAppDataFromSessionState(appData) + } + return +} + +func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { + peerVersion := serverHello.vers + if serverHello.supportedVersion != 0 { + peerVersion = serverHello.supportedVersion + } + + vers, ok := c.config.mutualVersion([]uint16{peerVersion}) + if !ok { + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) + } + + c.vers = vers + c.haveVers = true + c.in.version = vers + c.out.version = vers + + return nil +} + +// Does the handshake, either a full one or resumes old session. Requires hs.c, +// hs.hello, hs.serverHello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + c := hs.c + + isResume, err := hs.processServerHello() + if err != nil { + return err + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + + // No signatures of the handshake are needed in a resumption. + // Otherwise, in a full handshake, if we don't have any certificates + // configured then we will never send a CertificateVerify message and + // thus no signatures are needed in that case either. + if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { + hs.finishedHash.discardHandshakeBuffer() + } + + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + + c.buffering = true + c.didResume = isResume + if isResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = true + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + + hs.c.cipherSuite = hs.suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + msg, err = c.readHandshake() + if err != nil { + return err + } + + cs, ok := msg.(*certificateStatusMsg) + if ok { + // RFC4366 on Certificate Status Request: + // The server MAY return a "certificate_status" message. + + if !hs.serverHello.ocspStapling { + // If a server returns a "CertificateStatus" message, then the + // server MUST have included an extension of type "status_request" + // with empty "extension_data" in the extended server hello. + + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received unexpected CertificateStatus message") + } + hs.finishedHash.Write(cs.marshal()) + + c.ocspResponse = cs.response + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + + keyAgreement := hs.suite.ka(c.vers) + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.finishedHash.Write(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + hs.finishedHash.Write(certReq.marshal()) + + cri := certificateRequestInfoFromMsg(c.vers, certReq) + if chainToSend, err = c.getClientCertificate(cri); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.finishedHash.Write(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg = new(certificateMsg) + certMsg.certificates = chainToSend.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + hs.finishedHash.Write(ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } + } + + if chainToSend != nil && len(chainToSend.Certificate) > 0 { + certVerify := &certificateVerifyMsg{} + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + certVerify.hasSignatureAlgorithm = true + certVerify.signatureAlgorithm = signatureAlgorithm + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.finishedHash.Write(certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } + } + + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to write to key log: " + err.Error()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher interface{} + var clientHash, serverHash hash.Hash + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if err := hs.pickCipherSuite(); err != nil { + return false, err + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { + c.secureRenegotiation = true + if len(hs.serverHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + } + + if c.handshakes > 0 && c.secureRenegotiation { + var expectedSecureRenegotiation [24]byte + copy(expectedSecureRenegotiation[:], c.clientFinished[:]) + copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) + if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: incorrect renegotiation extension contents") + } + } + + if hs.serverHello.alpnProtocol != "" { + if len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server advertised unrequested ALPN extension") + } + if mutualProtocol([]string{hs.serverHello.alpnProtocol}, hs.hello.alpnProtocols) == "" { + c.sendAlert(alertUnsupportedExtension) + return false, errors.New("tls: server selected unadvertised ALPN protocol") + } + c.clientProtocol = hs.serverHello.alpnProtocol + } + + c.scts = hs.serverHello.scts + + if !hs.serverResumedSession() { + return false, nil + } + + if hs.session.vers != c.vers { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different version") + } + + if hs.session.cipherSuite != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different cipher suite") + } + + // Restore masterSecret, peerCerts, and ocspResponse from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + // Let the ServerHello SCTs override the session SCTs from the original + // connection, if any are provided + if len(c.scts) == 0 && len(hs.session.scts) != 0 { + c.scts = hs.session.scts + } + + return true, nil +} + +func (hs *clientHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &clientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + return nil +} + +func (hs *clientHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + copy(out, finished.verifyData) + return nil +} + +// verifyServerCertificate parses and verifies the provided chain, setting +// c.verifiedChains and c.peerCertificates or sending the appropriate alert. +func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + var err error + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + switch certs[0].PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) + } + + c.peerCertificates = certs + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS +// <= 1.2 CertificateRequest, making an effort to fill in missing information. +func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { + cri := &certificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + Version: vers, + } + + var rsaAvail, ecAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecAvail = true + } + } + + if !certReq.hasSignatureAlgorithm { + // Prior to TLS 1.2, signature schemes did not exist. In this case we + // make up a list based on the acceptable certificate types, to help + // GetClientCertificate and SupportsCertificate select the right certificate. + // The hash part of the SignatureScheme is a lie here, because + // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. + switch { + case rsaAvail && ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case rsaAvail: + cri.SignatureSchemes = []SignatureScheme{ + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + } + } + return toCertificateRequestInfo(cri) + } + + // Filter the signature schemes based on the certificate types. + // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). + cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) + for _, sigScheme := range certReq.supportedSignatureAlgorithms { + sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) + if err != nil { + continue + } + switch sigType { + case signatureECDSA, signatureEd25519: + if ecAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + case signatureRSAPSS, signaturePKCS1v15: + if rsaAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + } + } + + return toCertificateRequestInfo(cri) +} + +func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { + if c.config.GetClientCertificate != nil { + return c.config.GetClientCertificate(cri) + } + + for _, chain := range c.config.Certificates { + if err := cri.SupportsCertificate(&chain); err != nil { + continue + } + return &chain, nil + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + +const clientSessionCacheKeyPrefix = "qtls-" + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *config) string { + if len(config.ServerName) > 0 { + return clientSessionCacheKeyPrefix + config.ServerName + } + return clientSessionCacheKeyPrefix + serverAddr.String() +} + +// mutualProtocol finds the mutual ALPN protocol given list of possible +// protocols and a list of the preference order. +func mutualProtocol(protos, preferenceProtos []string) string { + for _, s := range preferenceProtos { + for _, c := range protos { + if s == c { + return s + } + } + } + return "" +} + +// hostnameInSNI converts name into an appropriate hostname for SNI. +// Literal IP addresses and absolute FQDNs are not permitted as SNI values. +// See RFC 6066, Section 3. +func hostnameInSNI(name string) string { + host := name + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + if i := strings.LastIndex(host, "%"); i > 0 { + host = host[:i] + } + if net.ParseIP(host) != nil { + return "" + } + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return name +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go new file mode 100644 index 00000000..fb70cec0 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go @@ -0,0 +1,740 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/hmac" + "crypto/rsa" + "encoding/binary" + "errors" + "fmt" + "hash" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +type clientHandshakeStateTLS13 struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheParams ecdheParameters + + session *clientSessionState + earlySecret []byte + binderKey []byte + + certReq *certificateRequestMsgTLS13 + usingPSK bool + sentDummyCCS bool + suite *cipherSuiteTLS13 + transcript hash.Hash + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 +} + +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. +func (hs *clientHandshakeStateTLS13) handshake() error { + c := hs.c + + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, + // sections 4.1.2 and 4.1.3. + if c.handshakes > 0 { + c.sendAlert(alertProtocolVersion) + return errors.New("tls: server selected TLS 1.3 in a renegotiation") + } + + // Consistency check on the presence of a keyShare and its parameters. + if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + return c.sendAlert(alertInternalError) + } + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + hs.transcript = hs.suite.hash.New() + hs.transcript.Write(hs.hello.marshal()) + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.processHelloRetryRequest(); err != nil { + return err + } + } + + hs.transcript.Write(hs.serverHello.marshal()) + + c.buffering = true + if err := hs.processServerHello(); err != nil { + return err + } + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.establishHandshakeKeys(); err != nil { + return err + } + if err := hs.readServerParameters(); err != nil { + return err + } + if err := hs.readServerCertificate(); err != nil { + return err + } + if err := hs.readServerFinished(); err != nil { + return err + } + if err := hs.sendClientCertificate(); err != nil { + return err + } + if err := hs.sendClientFinished(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// checkServerHelloOrHRR does validity checks that apply to both ServerHello and +// HelloRetryRequest messages. It sets hs.suite. +func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { + c := hs.c + + if hs.serverHello.supportedVersion == 0 { + c.sendAlert(alertMissingExtension) + return errors.New("tls: server selected TLS 1.3 using the legacy version field") + } + + if hs.serverHello.supportedVersion != VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid version after a HelloRetryRequest") + } + + if hs.serverHello.vers != VersionTLS12 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an incorrect legacy version") + } + + if hs.serverHello.ocspStapling || + hs.serverHello.ticketSupported || + hs.serverHello.secureRenegotiationSupported || + len(hs.serverHello.secureRenegotiation) != 0 || + len(hs.serverHello.alpnProtocol) != 0 || + len(hs.serverHello.scts) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") + } + + if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not echo the legacy session ID") + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported compression format") + } + + selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) + if hs.suite != nil && selectedSuite != hs.suite { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server changed cipher suite after a HelloRetryRequest") + } + if selectedSuite == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server chose an unconfigured cipher suite") + } + hs.suite = selectedSuite + c.cipherSuite = hs.suite.id + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and +// resends hs.hello, and reads the new ServerHello into hs.serverHello. +func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. (The idea is that the server might offload transcript + // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + hs.transcript.Write(hs.serverHello.marshal()) + + // The only HelloRetryRequest extensions we support are key_share and + // cookie, and clients must abort the handshake if the HRR would not result + // in any change in the ClientHello. + if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest message") + } + + if hs.serverHello.cookie != nil { + hs.hello.cookie = hs.serverHello.cookie + } + + if hs.serverHello.serverShare.group != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received malformed key_share extension") + } + + // If the server sent a key_share extension selecting a group, ensure it's + // a group we advertised but did not send a key share for, and send a key + // share for it this time. + if curveID := hs.serverHello.selectedGroup; curveID != 0 { + curveOK := false + for _, id := range hs.hello.supportedCurves { + if id == curveID { + curveOK = true + break + } + } + if !curveOK { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + if hs.ecdheParams.CurveID() == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.ecdheParams = params + hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + hs.hello.raw = nil + if len(hs.hello.pskIdentities) > 0 { + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash == hs.suite.hash { + // Update binders and obfuscated_ticket_age. + ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) + hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd + + transcript := hs.suite.hash.New() + transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + transcript.Write(chHash) + transcript.Write(hs.serverHello.marshal()) + transcript.Write(hs.hello.marshalWithoutBinders()) + pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} + hs.hello.updateBinders(pskBinders) + } else { + // Server selected a cipher suite incompatible with the PSK. + hs.hello.pskIdentities = nil + hs.hello.pskBinders = nil + } + } + + if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + hs.hello.earlyData = false // disable 0-RTT + + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + hs.serverHello = serverHello + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) processServerHello() error { + c := hs.c + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: server sent two HelloRetryRequest messages") + } + + if len(hs.serverHello.cookie) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a cookie in a normal ServerHello") + } + + if hs.serverHello.selectedGroup != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: malformed key_share extension") + } + + if hs.serverHello.serverShare.group == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not send a key share") + } + if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + + if !hs.serverHello.selectedIdentityPresent { + return nil + } + + if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK") + } + + if len(hs.hello.pskIdentities) != 1 || hs.session == nil { + return c.sendAlert(alertInternalError) + } + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash != hs.suite.hash { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK and cipher suite pair") + } + + hs.usingPSK = true + c.didResume = true + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + c.scts = hs.session.scts + return nil +} + +func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { + c := hs.c + + sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) + if sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + + earlySecret := hs.earlySecret + if !hs.usingPSK { + earlySecret = hs.suite.extract(nil, nil) + } + handshakeSecret := hs.suite.extract(sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.out.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(handshakeSecret, "derived", nil)) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerParameters() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + // Notify the caller if 0-RTT was rejected. + if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + c.used0RTT = encryptedExtensions.earlyData + if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { + hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) + } + hs.transcript.Write(encryptedExtensions.marshal()) + + if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection { + if len(encryptedExtensions.alpnProtocol) == 0 { + // the server didn't select an ALPN + c.sendAlert(alertNoApplicationProtocol) + return errors.New("ALPN negotiation failed. Server didn't offer any protocols") + } + if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.c.config.NextProtos) == "" { + // the protocol selected by the server was not offered + c.sendAlert(alertNoApplicationProtocol) + return fmt.Errorf("ALPN negotiation failed. Server offered: %q", encryptedExtensions.alpnProtocol) + } + } + if encryptedExtensions.alpnProtocol != "" { + if len(hs.hello.alpnProtocols) == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server advertised unrequested ALPN extension") + } + if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.hello.alpnProtocols) == "" { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server selected unadvertised ALPN protocol") + } + c.clientProtocol = encryptedExtensions.alpnProtocol + } + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerCertificate() error { + c := hs.c + + // Either a PSK or a certificate is always used, but not both. + // See RFC 8446, Section 4.1.1. + if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certReq, ok := msg.(*certificateRequestMsgTLS13) + if ok { + hs.transcript.Write(certReq.marshal()) + + hs.certReq = certReq + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + if len(certMsg.certificate.Certificate) == 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received empty certificates message") + } + hs.transcript.Write(certMsg.marshal()) + + c.scts = certMsg.certificate.SignedCertificateTimestamps + c.ocspResponse = certMsg.certificate.OCSPStaple + + if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + if !hmac.Equal(expectedMAC, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid server finished hash") + } + + hs.transcript.Write(finished.marshal()) + + // Derive secrets that take context through the server Finished. + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { + c := hs.c + + if hs.certReq == nil { + return nil + } + + cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{ + AcceptableCAs: hs.certReq.certificateAuthorities, + SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, + Version: c.vers, + })) + if err != nil { + return err + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *cert + certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + // If we sent an empty certificate message, skip the CertificateVerify. + if len(cert.Certificate) == 0 { + return nil + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + + certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) + if err != nil { + // getClientCertificate returned a certificate incompatible with the + // CertificateRequestInfo supported signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.out.setTrafficSecret(hs.suite, hs.trafficSecret) + + if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + } + + return nil +} + +func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received new session ticket from a client") + } + + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return nil + } + + // See RFC 8446, Section 4.6.1. + if msg.lifetime == 0 { + return nil + } + lifetime := time.Duration(msg.lifetime) * time.Second + if lifetime > maxSessionTicketLifetime { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: received a session ticket with invalid lifetime") + } + + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil || c.resumptionSecret == nil { + return c.sendAlert(alertInternalError) + } + + // We need to save the max_early_data_size that the server sent us, in order + // to decide if we're going to try 0-RTT with this ticket. + // However, at the same time, the qtls.ClientSessionTicket needs to be equal to + // the tls.ClientSessionTicket, so we can't just add a new field to the struct. + // We therefore abuse the nonce field (which is a byte slice) + nonceWithEarlyData := make([]byte, len(msg.nonce)+4) + binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData) + copy(nonceWithEarlyData[4:], msg.nonce) + + var appData []byte + if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil { + appData = c.extraConfig.GetAppDataForSessionState() + } + var b cryptobyte.Builder + b.AddUint16(clientSessionStateVersion) // revision + b.AddUint32(msg.maxEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(appData) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(msg.nonce) + }) + + // Save the resumption_master_secret and nonce instead of deriving the PSK + // to do the least amount of work on NewSessionTicket messages before we + // know if the ticket will be used. Forward secrecy of resumed connections + // is guaranteed by the requirement for pskModeDHE. + session := &clientSessionState{ + sessionTicket: msg.label, + vers: c.vers, + cipherSuite: c.cipherSuite, + masterSecret: c.resumptionSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + nonce: b.BytesOrPanic(), + useBy: c.config.time().Add(lifetime), + ageAdd: msg.ageAdd, + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session)) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go new file mode 100644 index 00000000..1ab75762 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go @@ -0,0 +1,1832 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "fmt" + "strings" + + "golang.org/x/crypto/cryptobyte" +) + +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. +func addUint64(b *cryptobyte.Builder, v uint64) { + b.AddUint32(uint32(v >> 32)) + b.AddUint32(uint32(v)) +} + +// readUint64 decodes a big-endian, 64-bit value into out and advances over it. +// It reports whether the read was successful. +func readUint64(s *cryptobyte.String, out *uint64) bool { + var hi, lo uint32 + if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { + return false + } + *out = uint64(hi)<<32 | uint64(lo) + return true +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocols []string + scts bool + supportedVersions []uint16 + cookie []byte + keyShares []keyShare + earlyData bool + pskModes []uint8 + pskIdentities []pskIdentity + pskBinders [][]byte + additionalExtensions []Extension +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeClientHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + b.AddUint16(suite) + } + }) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.compressionMethods) + }) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + b.AddUint16(extensionServerName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // name_type = host_name + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(1) // status_type = ocsp + b.AddUint16(0) // empty responder_id_list + b.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + b.AddUint16(extensionSupportedCurves) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + b.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + b.AddUint16(extensionSessionTicket) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + b.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ks := range m.keyShares { + b.AddUint16(uint16(ks.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + b.AddUint16(extensionPSKModes) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.pskModes) + }) + }) + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(psk.label) + }) + b.AddUint32(psk.obfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +// marshalWithoutBinders returns the ClientHello through the +// PreSharedKeyExtension.identities field, according to RFC 8446, Section +// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. +func (m *clientHelloMsg) marshalWithoutBinders() []byte { + bindersLen := 2 // uint16 length prefix + for _, binder := range m.pskBinders { + bindersLen += 1 // uint8 length prefix + bindersLen += len(binder) + } + + fullMessage := m.marshal() + return fullMessage[:len(fullMessage)-bindersLen] +} + +// updateBinders updates the m.pskBinders field, if necessary updating the +// cached marshaled representation. The supplied binders must have the same +// length as the current m.pskBinders. +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { + if len(pskBinders) != len(m.pskBinders) { + panic("tls: internal error: pskBinders length mismatch") + } + for i := range m.pskBinders { + if len(pskBinders[i]) != len(m.pskBinders[i]) { + panic("tls: internal error: pskBinders length mismatch") + } + } + m.pskBinders = pskBinders + if m.raw != nil { + lenWithoutBinders := len(m.marshalWithoutBinders()) + // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. + b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + if len(b.BytesOrPanic()) != len(m.raw) { + panic("tls: internal error: failed to update binders") + } + } +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + *m = clientHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) { + return false + } + + var cipherSuites cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuites) { + return false + } + m.cipherSuites = []uint16{} + m.secureRenegotiationSupported = false + for !cipherSuites.Empty() { + var suite uint16 + if !cipherSuites.ReadUint16(&suite) { + return false + } + if suite == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + m.cipherSuites = append(m.cipherSuites, suite) + } + + if !readUint8LengthPrefixed(&s, &m.compressionMethods) { + return false + } + + if s.Empty() { + // ClientHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var ext uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&ext) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch ext { + case extensionServerName: + // RFC 6066, Section 3 + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return false + } + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return false + } + if nameType != 0 { + continue + } + if len(m.serverName) != 0 { + // Multiple names of the same name_type are prohibited. + return false + } + m.serverName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(m.serverName, ".") { + return false + } + } + case extensionStatusRequest: + // RFC 4366, Section 3.6 + var statusType uint8 + var ignored cryptobyte.String + if !extData.ReadUint8(&statusType) || + !extData.ReadUint16LengthPrefixed(&ignored) || + !extData.ReadUint16LengthPrefixed(&ignored) { + return false + } + m.ocspStapling = statusType == statusTypeOCSP + case extensionSupportedCurves: + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + var curves cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { + return false + } + for !curves.Empty() { + var curve uint16 + if !curves.ReadUint16(&curve) { + return false + } + m.supportedCurves = append(m.supportedCurves, CurveID(curve)) + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + case extensionSessionTicket: + // RFC 5077, Section 3.2 + m.ticketSupported = true + extData.ReadBytes(&m.sessionTicket, len(extData)) + case extensionSignatureAlgorithms: + // RFC 5246, Section 7.4.1.4.1 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + // RFC 8446, Section 4.2.3 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionRenegotiationInfo: + // RFC 5746, Section 3.2 + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + // RFC 7301, Section 3.1 + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(proto)) + } + case extensionSCT: + // RFC 6962, Section 3.3.1 + m.scts = true + case extensionSupportedVersions: + // RFC 8446, Section 4.2.1 + var versList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { + return false + } + for !versList.Empty() { + var vers uint16 + if !versList.ReadUint16(&vers) { + return false + } + m.supportedVersions = append(m.supportedVersions, vers) + } + case extensionCookie: + // RFC 8446, Section 4.2.2 + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // RFC 8446, Section 4.2.8 + var clientShares cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&clientShares) { + return false + } + for !clientShares.Empty() { + var ks keyShare + if !clientShares.ReadUint16((*uint16)(&ks.group)) || + !readUint16LengthPrefixed(&clientShares, &ks.data) || + len(ks.data) == 0 { + return false + } + m.keyShares = append(m.keyShares, ks) + } + case extensionEarlyData: + // RFC 8446, Section 4.2.10 + m.earlyData = true + case extensionPSKModes: + // RFC 8446, Section 4.2.9 + if !readUint8LengthPrefixed(&extData, &m.pskModes) { + return false + } + case extensionPreSharedKey: + // RFC 8446, Section 4.2.11 + if !extensions.Empty() { + return false // pre_shared_key must be the last extension + } + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return false + } + for !identities.Empty() { + var psk pskIdentity + if !readUint16LengthPrefixed(&identities, &psk.label) || + !identities.ReadUint32(&psk.obfuscatedTicketAge) || + len(psk.label) == 0 { + return false + } + m.pskIdentities = append(m.pskIdentities, psk) + } + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return false + } + for !binders.Empty() { + var binder []byte + if !readUint8LengthPrefixed(&binders, &binder) || + len(binder) == 0 { + return false + } + m.pskBinders = append(m.pskBinders, binder) + } + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + ocspStapling bool + ticketSupported bool + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocol string + scts [][]byte + supportedVersion uint16 + serverShare keyShare + selectedIdentityPresent bool + selectedIdentity uint16 + supportedPoints []uint8 + + // HelloRetryRequest extensions + cookie []byte + selectedGroup CurveID +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeServerHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16(m.cipherSuite) + b.AddUint8(m.compressionMethod) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + b.AddUint16(extensionSessionTicket) + b.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.scts { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.serverShare.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + *m = serverHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) || + !s.ReadUint16(&m.cipherSuite) || + !s.ReadUint8(&m.compressionMethod) { + return false + } + + if s.Empty() { + // ServerHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSessionTicket: + m.ticketSupported = true + case extensionRenegotiationInfo: + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + m.scts = append(m.scts, sct) + } + case extensionSupportedVersions: + if !extData.ReadUint16(&m.supportedVersion) { + return false + } + case extensionCookie: + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // This extension has different formats in SH and HRR, accept either + // and let the handshake logic decide. See RFC 8446, Section 4.2.8. + if len(extData) == 2 { + if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { + return false + } + } else { + if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || + !readUint16LengthPrefixed(&extData, &m.serverShare.data) { + return false + } + } + case extensionPreSharedKey: + m.selectedIdentityPresent = true + if !extData.ReadUint16(&m.selectedIdentity) { + return false + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string + earlyData bool + + additionalExtensions []Extension +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeEncryptedExtensions) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + *m = encryptedExtensionsMsg{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var ext uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&ext) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch ext { + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionEarlyData: + m.earlyData = true + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type endOfEarlyDataMsg struct{} + +func (m *endOfEarlyDataMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeEndOfEarlyData + return x +} + +func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type keyUpdateMsg struct { + raw []byte + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +type newSessionTicketMsgTLS13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateRequestMsgTLS13 struct { + raw []byte + ocspStapling bool + scts bool + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateRequest) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // certificate_request_context (SHALL be zero length unless used for + // post-handshake authentication) + b.AddUint8(0) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.scts { + // RFC 8446, Section 4.4.2.1 makes no mention of + // signed_certificate_timestamp in CertificateRequest, but + // "Extensions in the Certificate message from the client MUST + // correspond to extensions in the CertificateRequest message + // from the server." and it appears in the table in Section 4.2. + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedSignatureAlgorithms) > 0 { + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + b.AddUint16(extensionCertificateAuthorities) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ca) + }) + } + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { + *m = certificateRequestMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSCT: + m.scts = true + case extensionSignatureAlgorithms: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionCertificateAuthorities: + var auths cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { + return false + } + for !auths.Empty() { + var ca []byte + if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { + return false + } + m.certificateAuthorities = append(m.certificateAuthorities, ca) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type certificateMsgTLS13 struct { + raw []byte + certificate Certificate + ocspStapling bool + scts bool +} + +func (m *certificateMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // certificate_request_context + + certificate := m.certificate + if !m.ocspStapling { + certificate.OCSPStaple = nil + } + if !m.scts { + certificate.SignedCertificateTimestamps = nil + } + marshalCertificate(b, certificate) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for i, cert := range certificate.Certificate { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if i > 0 { + // This library only supports OCSP and SCT for leaf certificates. + return + } + if certificate.OCSPStaple != nil { + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(certificate.OCSPStaple) + }) + }) + } + if certificate.SignedCertificateTimestamps != nil { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range certificate.SignedCertificateTimestamps { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + }) + } + }) +} + +func (m *certificateMsgTLS13) unmarshal(data []byte) bool { + *m = certificateMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !unmarshalCertificate(&s, &m.certificate) || + !s.Empty() { + return false + } + + m.scts = m.certificate.SignedCertificateTimestamps != nil + m.ocspStapling = m.certificate.OCSPStaple != nil + + return true +} + +func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + var extensions cryptobyte.String + if !readUint24LengthPrefixed(&certList, &cert) || + !certList.ReadUint16LengthPrefixed(&extensions) { + return false + } + certificate.Certificate = append(certificate.Certificate, cert) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + if len(certificate.Certificate) > 1 { + // This library only supports OCSP and SCT for leaf certificates. + continue + } + + switch extension { + case extensionStatusRequest: + var statusType uint8 + if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || + len(certificate.OCSPStaple) == 0 { + return false + } + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + certificate.SignedCertificateTimestamps = append( + certificate.SignedCertificateTimestamps, sct) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + } + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + response []byte +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateStatus) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.response) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var statusType uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&s, &m.response) || + len(m.response) == 0 || !s.Empty() { + return false + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeFinished) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.verifyData) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + return s.Skip(1) && + readUint24LengthPrefixed(&s, &m.verifyData) && + s.Empty() +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAlgorithm indicates whether this message includes a list of + // supported signature algorithms. This change was introduced with TLS 1.2. + hasSignatureAlgorithm bool + + certificateTypes []byte + supportedSignatureAlgorithms []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 4346, Section 7.4.4. + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAlgorithm { + length += 2 + 2*len(m.supportedSignatureAlgorithms) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAlgorithm { + n := len(m.supportedSignatureAlgorithms) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAlgo := range m.supportedSignatureAlgorithms { + y[0] = uint8(sigAlgo >> 8) + y[1] = uint8(sigAlgo) + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAlgorithm { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAlgos := sigAndHashLen / 2 + m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + return len(data) == 0 +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAlgorithm bool // format change introduced in TLS 1.2 + signatureAlgorithm SignatureScheme + signature []byte +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateVerify) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + b.AddUint16(uint16(m.signatureAlgorithm)) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.signature) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + if !s.Skip(4) { // message type and uint24 length field + return false + } + if m.hasSignatureAlgorithm { + if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { + return false + } + } + return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 5077, Section 3.3. + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go new file mode 100644 index 00000000..5d39cc8f --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go @@ -0,0 +1,878 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "sync/atomic" + "time" +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + clientHello *clientHelloMsg + hello *serverHelloMsg + suite *cipherSuite + ecdheOk bool + ecSignOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + masterSecret []byte + cert *Certificate +} + +// serverHandshake performs a TLS handshake as a server. +func (c *Conn) serverHandshake() error { + c.setAlternativeRecordLayer() + + clientHello, err := c.readClientHello() + if err != nil { + return err + } + + if c.vers == VersionTLS13 { + hs := serverHandshakeStateTLS13{ + c: c, + clientHello: clientHello, + } + return hs.handshake() + } else if c.extraConfig.usesAlternativeRecordLayer() { + // This should already have been caught by the check that the ClientHello doesn't + // offer any (supported) versions older than TLS 1.3. + // Check again to make sure we can't be tricked into using an older version. + c.sendAlert(alertProtocolVersion) + return errors.New("tls: negotiated TLS < 1.3 when using QUIC") + } + + hs := serverHandshakeState{ + c: c, + clientHello: clientHello, + } + return hs.handshake() +} + +func (hs *serverHandshakeState) handshake() error { + c := hs.c + + if err := hs.processClientHello(); err != nil { + return err + } + + // For an overview of TLS handshaking, see RFC 5246, Section 7.3. + c.buffering = true + if hs.checkForResumption() { + // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true + if err := hs.doResumeHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(c.serverFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = false + if err := hs.readFinished(nil); err != nil { + return err + } + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + if err := hs.pickCipherSuite(); err != nil { + return err + } + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readFinished(c.clientFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = true + c.buffering = true + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(nil); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// readClientHello reads a ClientHello message and selects the protocol version. +func (c *Conn) readClientHello() (*clientHelloMsg, error) { + msg, err := c.readHandshake() + if err != nil { + return nil, err + } + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return nil, unexpectedMessageError(clientHello, msg) + } + + var configForClient *config + originalConfig := c.config + if c.config.GetConfigForClient != nil { + chi := newClientHelloInfo(c, clientHello) + if cfc, err := c.config.GetConfigForClient(chi); err != nil { + c.sendAlert(alertInternalError) + return nil, err + } else if cfc != nil { + configForClient = fromConfig(cfc) + c.config = configForClient + } + } + c.ticketKeys = originalConfig.ticketKeys(configForClient) + + clientVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + clientVersions = supportedVersionsFromMax(clientHello.vers) + } + if c.extraConfig.usesAlternativeRecordLayer() { + // In QUIC, the client MUST NOT offer any old TLS versions. + // Here, we can only check that none of the other supported versions of this library + // (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here. + for _, ver := range clientVersions { + if ver == VersionTLS13 { + continue + } + for _, v := range supportedVersions { + if ver == v { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver) + } + } + } + // Make the config we're using allows us to use TLS 1.3. + if c.config.maxSupportedVersion() < VersionTLS13 { + c.sendAlert(alertInternalError) + return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + } + c.vers, ok = c.config.mutualVersion(clientVersions) + if !ok { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) + } + c.haveVers = true + c.in.version = c.vers + c.out.version = c.vers + + return clientHello, nil +} + +func (hs *serverHandshakeState) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client does not support uncompressed connections") + } + + hs.hello.random = make([]byte, 32) + serverRandom := hs.hello.random + // Downgrade protection canaries. See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { + if c.vers == VersionTLS12 { + copy(serverRandom[24:], downgradeCanaryTLS12) + } else { + copy(serverRandom[24:], downgradeCanaryTLS11) + } + serverRandom = serverRandom[:24] + } + _, err := io.ReadFull(c.config.rand(), serverRandom) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + } + } + + hs.cert, err = c.config.getCertificate(newClientHelloInfo(c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + if hs.clientHello.scts { + hs.hello.scts = hs.cert.SignedCertificateTimestamps + } + + hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) + + if hs.ecdheOk { + // Although omitting the ec_point_formats extension is permitted, some + // old OpenSSL version will refuse to handshake if not present. + // + // Per RFC 4492, section 5.1.2, implementations MUST support the + // uncompressed point format. See golang.org/issue/31943. + hs.hello.supportedPoints = []uint8{pointFormatUncompressed} + } + + if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { + switch priv.Public().(type) { + case *ecdsa.PublicKey: + hs.ecSignOk = true + case ed25519.PublicKey: + hs.ecSignOk = true + case *rsa.PublicKey: + hs.rsaSignOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) + } + } + if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { + switch priv.Public().(type) { + case *rsa.PublicKey: + hs.rsaDecryptOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) + } + } + + return nil +} + +// supportsECDHE returns whether ECDHE key exchanges can be used with this +// pre-TLS 1.3 client. +func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool { + supportsCurve := false + for _, curve := range supportedCurves { + if c.supportsCurve(curve) { + supportsCurve = true + break + } + } + + supportsPointFormat := false + for _, pointFormat := range supportedPoints { + if pointFormat == pointFormatUncompressed { + supportsPointFormat = true + break + } + } + + return supportsCurve && supportsPointFormat +} + +func (hs *serverHandshakeState) pickCipherSuite() error { + c := hs.c + + var preferenceList, supportedList []uint16 + if c.config.PreferServerCipherSuites { + preferenceList = c.config.cipherSuites() + supportedList = hs.clientHello.cipherSuites + + // If the client does not seem to have hardware support for AES-GCM, + // and the application did not specify a cipher suite preference order, + // prefer other AEAD ciphers even if we prioritized AES-GCM ciphers + // by default. + if c.config.CipherSuites == nil && !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = deprioritizeAES(preferenceList) + } + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = c.config.cipherSuites() + + // If we don't have hardware support for AES-GCM, prefer other AEAD + // ciphers even if the client prioritized AES-GCM. + if !hasAESGCMHardwareSupport { + preferenceList = deprioritizeAES(preferenceList) + } + } + + hs.suite = selectCipherSuite(preferenceList, supportedList, hs.cipherSuiteOk) + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. See RFC 7507. + if hs.clientHello.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return nil +} + +func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + if !hs.ecdheOk { + return false + } + if c.flags&suiteECSign != 0 { + if !hs.ecSignOk { + return false + } + } else if !hs.rsaSignOk { + return false + } + } else if !hs.rsaDecryptOk { + return false + } + if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true +} + +// checkForResumption reports whether we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if c.config.SessionTicketsDisabled { + return false + } + + plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) + if plaintext == nil { + return false + } + hs.sessionState = &sessionState{usedOldKey: usedOldKey} + ok := hs.sessionState.unmarshal(plaintext) + if !ok { + return false + } + + createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + return false + } + + // Never resume a session for a different TLS version. + if c.vers != hs.sessionState.vers { + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return false + } + + // Check that we also support the ciphersuite from the session. + hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, + c.config.cipherSuites(), hs.cipherSuiteOk) + if hs.suite == nil { + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + return false + } + + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + hs.hello.cipherSuite = hs.suite.id + c.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = hs.sessionState.usedOldKey + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := c.processCertsFromClient(Certificate{ + Certificate: hs.sessionState.certificates, + }); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + + return nil +} + +func (hs *serverHandshakeState) doFullHandshake() error { + c := hs.c + + if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled + hs.hello.cipherSuite = hs.suite.id + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + if c.config.ClientAuth == NoClientCert { + // No need to keep a full record of the handshake if client + // certificates won't be used. + hs.finishedHash.discardHandshakeBuffer() + } + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.response = hs.cert.OCSPStaple + hs.finishedHash.Write(certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } + } + + keyAgreement := hs.suite.ka(c.vers) + skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil { + hs.finishedHash.Write(skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } + } + + var certReq *certificateRequestMsg + if c.config.ClientAuth >= RequestClientCert { + // Request a client certificate + certReq = new(certificateRequestMsg) + certReq.certificateTypes = []byte{ + byte(certTypeRSASign), + byte(certTypeECDSASign), + } + if c.vers >= VersionTLS12 { + certReq.hasSignatureAlgorithm = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + hs.finishedHash.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + helloDone := new(serverHelloDoneMsg) + hs.finishedHash.Write(helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } + + if _, err := c.flush(); err != nil { + return err + } + + var pub crypto.PublicKey // public key for client auth, if any + + msg, err := c.readHandshake() + if err != nil { + return err + } + + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + if c.config.ClientAuth >= RequestClientCert { + certMsg, ok := msg.(*certificateMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(Certificate{ + Certificate: certMsg.certificates, + }); err != nil { + return err + } + if len(certMsg.certificates) != 0 { + pub = c.peerCertificates[0].PublicKey + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.finishedHash.Write(ckx.marshal()) + + preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.finishedHash.Write(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher interface{} + var clientHash, serverHash hash.Hash + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + + return nil +} + +func (hs *serverHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") + } + + hs.finishedHash.Write(clientFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + m := new(newSessionTicketMsg) + + createdAt := uint64(c.config.time().Unix()) + if hs.sessionState != nil { + // If this is re-wrapping an old key, then keep + // the original time it was created. + createdAt = hs.sessionState.createdAt + } + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + masterSecret: hs.masterSecret, + certificates: certsFromClient, + } + var err error + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + copy(out, finished.verifyData) + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (c *Conn) processCertsFromClient(certificate Certificate) error { + certificates := certificate.Certificate + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to verify client certificate: " + err.Error()) + } + + c.verifiedChains = chains + } + + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +func newClientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { + supportedVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + supportedVersions = supportedVersionsFromMax(clientHello.vers) + } + + return toClientHelloInfo(&clientHelloInfo{ + CipherSuites: clientHello.cipherSuites, + ServerName: clientHello.serverName, + SupportedCurves: clientHello.supportedCurves, + SupportedPoints: clientHello.supportedPoints, + SignatureSchemes: clientHello.supportedSignatureAlgorithms, + SupportedProtos: clientHello.alpnProtocols, + SupportedVersions: supportedVersions, + Conn: c.conn, + config: toConfig(c.config), + }) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go new file mode 100644 index 00000000..e1ab918d --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go @@ -0,0 +1,912 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "fmt" + "hash" + "io" + "sync/atomic" + "time" +) + +// maxClientPSKIdentities is the number of client PSK identities the server will +// attempt to validate. It will ignore the rest not to let cheap ClientHello +// messages cause too much work in session ticket decryption attempts. +const maxClientPSKIdentities = 5 + +type serverHandshakeStateTLS13 struct { + c *Conn + clientHello *clientHelloMsg + hello *serverHelloMsg + encryptedExtensions *encryptedExtensionsMsg + sentDummyCCS bool + usingPSK bool + suite *cipherSuiteTLS13 + cert *Certificate + sigAlg SignatureScheme + earlySecret []byte + sharedKey []byte + handshakeSecret []byte + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 + transcript hash.Hash + clientFinished []byte +} + +func (hs *serverHandshakeStateTLS13) handshake() error { + c := hs.c + + // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. + if err := hs.processClientHello(); err != nil { + return err + } + if err := hs.checkForResumption(); err != nil { + return err + } + if err := hs.pickCertificate(); err != nil { + return err + } + c.buffering = true + if err := hs.sendServerParameters(); err != nil { + return err + } + if err := hs.sendServerCertificate(); err != nil { + return err + } + if err := hs.sendServerFinished(); err != nil { + return err + } + // Note that at this point we could start sending application data without + // waiting for the client's second flight, but the application might not + // expect the lack of replay protection of the ClientHello parameters. + if _, err := c.flush(); err != nil { + return err + } + if err := hs.readClientCertificate(); err != nil { + return err + } + if err := hs.readClientFinished(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *serverHandshakeStateTLS13) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.encryptedExtensions = new(encryptedExtensionsMsg) + + // TLS 1.3 froze the ServerHello.legacy_version field, and uses + // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. + hs.hello.vers = VersionTLS12 + hs.hello.supportedVersion = c.vers + + if len(hs.clientHello.supportedVersions) == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") + } + + // Abort if the client is doing a fallback and landing lower than what we + // support. See RFC 7507, which however does not specify the interaction + // with supported_versions. The only difference is that with + // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] + // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, + // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to + // TLS 1.2, because a TLS 1.3 server would abort here. The situation before + // supported_versions was not better because there was just no way to do a + // TLS 1.4 handshake without risking the server selecting TLS 1.3. + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // Use c.vers instead of max(supported_versions) because an attacker + // could defeat this by adding an arbitrary high version otherwise. + if c.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + if len(hs.clientHello.compressionMethods) != 1 || + hs.clientHello.compressionMethods[0] != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: TLS 1.3 client supports illegal compression methods") + } + + hs.hello.random = make([]byte, 32) + if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.compressionMethod = compressionNone + + var preferenceList, supportedList, ourList []uint16 + var useConfiguredCipherSuites bool + for _, suiteID := range c.config.CipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + ourList = append(ourList, suiteID) + break + } + } + } + if len(ourList) > 0 { + useConfiguredCipherSuites = true + } else { + ourList = defaultCipherSuitesTLS13() + } + if c.config.PreferServerCipherSuites { + preferenceList = ourList + supportedList = hs.clientHello.cipherSuites + + // If the client does not seem to have hardware support for AES-GCM, + // prefer other AEAD ciphers even if we prioritized AES-GCM ciphers + // by default. + if !useConfiguredCipherSuites && !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = deprioritizeAES(preferenceList) + } + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = ourList + + // If we don't have hardware support for AES-GCM, prefer other AEAD + // ciphers even if the client prioritized AES-GCM. + if !hasAESGCMHardwareSupport { + preferenceList = deprioritizeAES(preferenceList) + } + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(supportedList, suiteID) + if hs.suite != nil { + break + } + } + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + hs.hello.cipherSuite = hs.suite.id + hs.transcript = hs.suite.hash.New() + + // Pick the ECDHE group in server preference order, but give priority to + // groups with a key share, to avoid a HelloRetryRequest round-trip. + var selectedGroup CurveID + var clientKeyShare *keyShare +GroupSelection: + for _, preferredGroup := range c.config.curvePreferences() { + for _, ks := range hs.clientHello.keyShares { + if ks.group == preferredGroup { + selectedGroup = ks.group + clientKeyShare = &ks + break GroupSelection + } + } + if selectedGroup != 0 { + continue + } + for _, group := range hs.clientHello.supportedCurves { + if group == preferredGroup { + selectedGroup = group + break + } + } + } + if selectedGroup == 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no ECDHE curve supported by both client and server") + } + if clientKeyShare == nil { + if err := hs.doHelloRetryRequest(selectedGroup); err != nil { + return err + } + clientKeyShare = &hs.clientHello.keyShares[0] + } + + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} + hs.sharedKey = params.SharedKey(clientKeyShare.data) + if hs.sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid client key share") + } + + c.serverName = hs.clientHello.serverName + + if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil { + c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions) + } + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { + hs.encryptedExtensions.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + } + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) checkForResumption() error { + c := hs.c + + if c.config.SessionTicketsDisabled { + return nil + } + + modeOK := false + for _, mode := range hs.clientHello.pskModes { + if mode == pskModeDHE { + modeOK = true + break + } + } + if !modeOK { + return nil + } + + if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid or missing PSK binders") + } + if len(hs.clientHello.pskIdentities) == 0 { + return nil + } + + for i, identity := range hs.clientHello.pskIdentities { + if i >= maxClientPSKIdentities { + break + } + + plaintext, _ := c.decryptTicket(identity.label) + if plaintext == nil { + continue + } + sessionState := new(sessionStateTLS13) + if ok := sessionState.unmarshal(plaintext); !ok { + continue + } + + if hs.clientHello.earlyData { + if sessionState.maxEarlyData == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: client sent unexpected early data") + } + + if sessionState.alpn == c.clientProtocol && + c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 && + c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) { + hs.encryptedExtensions.earlyData = true + c.used0RTT = true + } + } + + createdAt := time.Unix(int64(sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + continue + } + + // We don't check the obfuscated ticket age because it's affected by + // clock skew and it's only a freshness signal useful for shrinking the + // window for replay attacks, which don't affect us as we don't do 0-RTT. + + pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) + if pskSuite == nil || pskSuite.hash != hs.suite.hash { + continue + } + + // PSK connections don't re-establish client certificates, but carry + // them over in the session ticket. Ensure the presence of client certs + // in the ticket is consistent with the configured requirements. + sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + continue + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + continue + } + + psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", + nil, hs.suite.hash.Size()) + hs.earlySecret = hs.suite.extract(psk, nil) + binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) + // Clone the transcript in case a HelloRetryRequest was recorded. + transcript := cloneHash(hs.transcript, hs.suite.hash) + if transcript == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: failed to clone hash") + } + transcript.Write(hs.clientHello.marshalWithoutBinders()) + pskBinder := hs.suite.finishedHash(binderKey, transcript) + if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid PSK binder") + } + + c.didResume = true + if err := c.processCertsFromClient(sessionState.certificate); err != nil { + return err + } + + h := cloneHash(hs.transcript, hs.suite.hash) + h.Write(hs.clientHello.marshal()) + if hs.encryptedExtensions.earlyData { + clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) + c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + + hs.hello.selectedIdentityPresent = true + hs.hello.selectedIdentity = uint16(i) + hs.usingPSK = true + return nil + } + + return nil +} + +// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// interfaces implemented by standard library hashes to clone the state of in +// to a new instance of h. It returns nil if the operation fails. +func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { + // Recreate the interface to avoid importing encoding. + type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) + UnmarshalBinary(data []byte) error + } + marshaler, ok := in.(binaryMarshaler) + if !ok { + return nil + } + state, err := marshaler.MarshalBinary() + if err != nil { + return nil + } + out := h.New() + unmarshaler, ok := out.(binaryMarshaler) + if !ok { + return nil + } + if err := unmarshaler.UnmarshalBinary(state); err != nil { + return nil + } + return out +} + +func (hs *serverHandshakeStateTLS13) pickCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. + if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { + return c.sendAlert(alertMissingExtension) + } + + certificate, err := c.config.getCertificate(newClientHelloInfo(c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) + if err != nil { + // getCertificate returned a certificate that is unsupported or + // incompatible with the client's signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + hs.cert = certificate + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. See RFC 8446, Section 4.4.1. + hs.transcript.Write(hs.clientHello.marshal()) + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + + helloRetryRequest := &serverHelloMsg{ + vers: hs.hello.vers, + random: helloRetryRequestRandom, + sessionId: hs.hello.sessionId, + cipherSuite: hs.hello.cipherSuite, + compressionMethod: hs.hello.compressionMethod, + supportedVersion: hs.hello.supportedVersion, + selectedGroup: selectedGroup, + } + + hs.transcript.Write(helloRetryRequest.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientHello, msg) + } + + if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client sent invalid key share in second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client indicated early data in second ClientHello") + } + + if illegalClientHelloChange(clientHello, hs.clientHello) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client illegally modified second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client offered 0-RTT data in second ClientHello") + } + + hs.clientHello = clientHello + return nil +} + +// illegalClientHelloChange reports whether the two ClientHello messages are +// different, with the exception of the changes allowed before and after a +// HelloRetryRequest. See RFC 8446, Section 4.1.2. +func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { + if len(ch.supportedVersions) != len(ch1.supportedVersions) || + len(ch.cipherSuites) != len(ch1.cipherSuites) || + len(ch.supportedCurves) != len(ch1.supportedCurves) || + len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || + len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || + len(ch.alpnProtocols) != len(ch1.alpnProtocols) { + return true + } + for i := range ch.supportedVersions { + if ch.supportedVersions[i] != ch1.supportedVersions[i] { + return true + } + } + for i := range ch.cipherSuites { + if ch.cipherSuites[i] != ch1.cipherSuites[i] { + return true + } + } + for i := range ch.supportedCurves { + if ch.supportedCurves[i] != ch1.supportedCurves[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithms { + if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithmsCert { + if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { + return true + } + } + for i := range ch.alpnProtocols { + if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { + return true + } + } + return ch.vers != ch1.vers || + !bytes.Equal(ch.random, ch1.random) || + !bytes.Equal(ch.sessionId, ch1.sessionId) || + !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || + ch.serverName != ch1.serverName || + ch.ocspStapling != ch1.ocspStapling || + !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || + ch.ticketSupported != ch1.ticketSupported || + !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || + ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || + !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || + ch.scts != ch1.scts || + !bytes.Equal(ch.cookie, ch1.cookie) || + !bytes.Equal(ch.pskModes, ch1.pskModes) +} + +func (hs *serverHandshakeStateTLS13) sendServerParameters() error { + c := hs.c + + if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection && len(c.clientProtocol) == 0 { + c.sendAlert(alertNoApplicationProtocol) + return fmt.Errorf("ALPN negotiation failed. Client offered: %q", hs.clientHello.alpnProtocols) + } + + hs.transcript.Write(hs.clientHello.marshal()) + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + earlySecret := hs.earlySecret + if earlySecret == nil { + earlySecret = hs.suite.extract(nil, nil) + } + hs.handshakeSecret = hs.suite.extract(hs.sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.in.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil { + hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) + } + + hs.transcript.Write(hs.encryptedExtensions.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) requestClientCert() bool { + return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK +} + +func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + if hs.requestClientCert() { + // Request a client certificate + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + + hs.transcript.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *hs.cert + certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + certVerifyMsg.signatureAlgorithm = hs.sigAlg + + sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + public := hs.cert.PrivateKey.(crypto.Signer).Public() + if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && + rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS + c.sendAlert(alertHandshakeFailure) + } else { + c.sendAlert(alertInternalError) + } + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) sendServerFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + // Derive secrets that take context through the server Finished. + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + // If we did not request client certificates, at this point we can + // precompute the client finished and roll the transcript forward to send + // session tickets in our first flight. + if !hs.requestClientCert() { + if err := hs.sendSessionTickets(); err != nil { + return err + } + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { + if hs.c.config.SessionTicketsDisabled { + return false + } + + // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. + for _, pskMode := range hs.clientHello.pskModes { + if pskMode == pskModeDHE { + return true + } + } + return false +} + +func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { + c := hs.c + + hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + finishedMsg := &finishedMsg{ + verifyData: hs.clientFinished, + } + hs.transcript.Write(finishedMsg.marshal()) + + if !hs.shouldSendSessionTickets() { + return nil + } + + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + + // Don't send session tickets when the alternative record layer is set. + // Instead, save the resumption secret on the Conn. + // Session tickets can then be generated by calling Conn.GetSessionTicket(). + if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil { + return nil + } + + m, err := hs.c.getSessionTicketMsg(nil) + if err != nil { + return err + } + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientCertificate() error { + c := hs.c + + if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + // If we requested a client certificate, then the client must send a + // certificate message. If it's empty, no CertificateVerify is sent. + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.transcript.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(certMsg.certificate); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if len(certMsg.certificate.Certificate) != 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + } + + // If we waited until the client certificates to send session tickets, we + // are ready to do it now. + if err := hs.sendSessionTickets(); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + if !hmac.Equal(hs.clientFinished, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid client finished hash") + } + + c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.in.setTrafficSecret(hs.suite, hs.trafficSecret) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go b/vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go new file mode 100644 index 00000000..d8f5d469 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go @@ -0,0 +1,338 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "errors" + "fmt" + "io" +) + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct{} + +func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext := ckx.ciphertext[2:] + + priv, ok := cert.PrivateKey.(crypto.Decrypter) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") + } + // Perform constant time RSA PKCS #1 v1.5 decryption + preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) + if err != nil { + return nil, err + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + return preMasterSecret, nil +} + +func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + preMasterSecret[0] = byte(clientHello.vers >> 8) + preMasterSecret[1] = byte(clientHello.vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") + } + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// using the given hash function (for >= TLS 1.2) or using a default based on +// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't +// do pre-hashing, it returns the concatenation of the slices. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { + if sigType == signatureEd25519 { + var signed []byte + for _, slice := range slices { + signed = append(signed, slice...) + } + return signed + } + if version >= VersionTLS12 { + h := hashFunc.New() + for _, slice := range slices { + h.Write(slice) + } + digest := h.Sum(nil) + return digest + } + if sigType == signatureECDSA { + return sha1Hash(slices) + } + return md5SHA1Hash(slices) +} + +// ecdheKeyAgreement implements a TLS key agreement where the server +// generates an ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// be ECDSA, Ed25519 or RSA. +type ecdheKeyAgreement struct { + version uint16 + isRSA bool + params ecdheParameters + + // ckx and preMasterSecret are generated in processServerKeyExchange + // and returned in generateClientKeyExchange. + ckx *clientKeyExchangeMsg + preMasterSecret []byte +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + var curveID CurveID + for _, c := range clientHello.supportedCurves { + if config.supportsCurve(c) { + curveID = c + break + } + } + + if curveID == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, err + } + ka.params = params + + // See RFC 4492, Section 5.4. + ecdhePublic := params.PublicKey() + serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHEParams[0] = 3 // named curve + serverECDHEParams[1] = byte(curveID >> 8) + serverECDHEParams[2] = byte(curveID) + serverECDHEParams[3] = byte(len(ecdhePublic)) + copy(serverECDHEParams[4:], ecdhePublic) + + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) + } + + var signatureAlgorithm SignatureScheme + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) + if err != nil { + return nil, err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return nil, err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) + if err != nil { + return nil, err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) + + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := priv.Sign(config.rand(), signed, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } + + skx := new(serverKeyExchangeMsg) + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) + copy(skx.key, serverECDHEParams) + k := skx.key[len(serverECDHEParams):] + if ka.version >= VersionTLS12 { + k[0] = byte(signatureAlgorithm >> 8) + k[1] = byte(signatureAlgorithm) + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + + return skx, nil +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + + preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) + if preMasterSecret == nil { + return nil, errClientKeyExchange + } + + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + serverECDHEParams := skx.key[:4+publicLen] + publicKey := serverECDHEParams[4:] + + sig := skx.key[4+publicLen:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return errors.New("tls: server selected unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return err + } + ka.params = params + + ka.preMasterSecret = params.SharedKey(publicKey) + if ka.preMasterSecret == nil { + return errServerKeyExchange + } + + ourPublicKey := params.PublicKey() + ka.ckx = new(clientKeyExchangeMsg) + ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) + ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) + copy(ka.ckx.ciphertext[1:], ourPublicKey) + + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) + if err != nil { + return err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + return nil +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.ckx == nil { + return nil, nil, errors.New("tls: missing ServerKeyExchange message") + } + + return ka.preMasterSecret, ka.ckx, nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go b/vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go new file mode 100644 index 00000000..da13904a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go @@ -0,0 +1,199 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto/elliptic" + "crypto/hmac" + "errors" + "hash" + "io" + "math/big" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +// This file contains the functions necessary to compute the TLS 1.3 key +// schedule. See RFC 8446, Section 7. + +const ( + resumptionBinderLabel = "res binder" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + exporterLabel = "exp master" + resumptionLabel = "res master" + trafficUpdateLabel = "traffic upd" +) + +// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("tls: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { + if transcript == nil { + transcript = c.hash.New() + } + return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) +} + +// extract implements HKDF-Extract with the cipher suite hash. +func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, c.hash.Size()) + } + return hkdf.Extract(c.hash.New, newSecret, currentSecret) +} + +// nextTrafficSecret generates the next traffic secret, given the current one, +// according to RFC 8446, Section 7.2. +func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { + return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) +} + +// trafficKey generates traffic keys according to RFC 8446, Section 7.3. +func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { + key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) + iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + return +} + +// finishedHash generates the Finished verify_data or PskBinderEntry according +// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey +// selection. +func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { + finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) + verifyData := hmac.New(c.hash.New, finishedKey) + verifyData.Write(transcript.Sum(nil)) + return verifyData.Sum(nil) +} + +// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to +// RFC 8446, Section 7.5. +func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) + return func(label string, context []byte, length int) ([]byte, error) { + secret := c.deriveSecret(expMasterSecret, label, nil) + h := c.hash.New() + h.Write(context) + return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + } +} + +// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// according to RFC 8446, Section 4.2.8.2. +type ecdheParameters interface { + CurveID() CurveID + PublicKey() []byte + SharedKey(peerPublicKey []byte) []byte +} + +func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { + if curveID == X25519 { + privateKey := make([]byte, curve25519.ScalarSize) + if _, err := io.ReadFull(rand, privateKey); err != nil { + return nil, err + } + publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) + if err != nil { + return nil, err + } + return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, errors.New("tls: internal error: unsupported curve") + } + + p := &nistParameters{curveID: curveID} + var err error + p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) + if err != nil { + return nil, err + } + return p, nil +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } +} + +type nistParameters struct { + privateKey []byte + x, y *big.Int // public key + curveID CurveID +} + +func (p *nistParameters) CurveID() CurveID { + return p.curveID +} + +func (p *nistParameters) PublicKey() []byte { + curve, _ := curveForCurveID(p.curveID) + return elliptic.Marshal(curve, p.x, p.y) +} + +func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { + curve, _ := curveForCurveID(p.curveID) + // Unmarshal also checks whether the given point is on the curve. + x, y := elliptic.Unmarshal(curve, peerPublicKey) + if x == nil { + return nil + } + + xShared, _ := curve.ScalarMult(x, y, p.privateKey) + sharedKey := make([]byte, (curve.Params().BitSize+7)/8) + return xShared.FillBytes(sharedKey) +} + +type x25519Parameters struct { + privateKey []byte + publicKey []byte +} + +func (p *x25519Parameters) CurveID() CurveID { + return X25519 +} + +func (p *x25519Parameters) PublicKey() []byte { + return p.publicKey[:] +} + +func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { + sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) + if err != nil { + return nil + } + return sharedKey +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/prf.go b/vendor/github.com/marten-seemann/qtls-go1-16/prf.go new file mode 100644 index 00000000..9eb0221a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/prf.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +const ( + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") + +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + switch version { + case VersionTLS10, VersionTLS11: + return prf10, crypto.Hash(0) + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + default: + panic("unknown version") + } +} + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See RFC 5246, Section 8.1. +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { + seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) + return masterSecret +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, Section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) + seed = append(seed, serverRandom...) + seed = append(seed, clientRandom...) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version >= VersionTLS12 { + buffer = []byte{} + } + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// hashForClientCertificate returns the handshake messages so far, pre-hashed if +// necessary, suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { + if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { + panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") + } + + if sigType == signatureEd25519 { + return h.buffer + } + + if h.version >= VersionTLS12 { + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil) + } + + if sigType == signatureECDSA { + return h.server.Sum(nil) + } + + return h.Sum() +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} + +// noExportedKeyingMaterial is used as a value of +// ConnectionState.ekm when renegotiation is enabled and thus +// we wish to fail all key-material export requests. +func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") +} + +// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. +func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { + return func(label string, context []byte, length int) ([]byte, error) { + switch label { + case "client finished", "server finished", "master secret", "key expansion": + // These values are reserved and may not be used. + return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) + } + + seedLen := len(serverRandom) + len(clientRandom) + if context != nil { + seedLen += 2 + len(context) + } + seed := make([]byte, 0, seedLen) + + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + if context != nil { + if len(context) >= 1<<16 { + return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") + } + seed = append(seed, byte(len(context)>>8), byte(len(context))) + seed = append(seed, context...) + } + + keyMaterial := make([]byte, length) + prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) + return keyMaterial, nil + } +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/ticket.go b/vendor/github.com/marten-seemann/qtls-go1-16/ticket.go new file mode 100644 index 00000000..006b8c1d --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/ticket.go @@ -0,0 +1,259 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "errors" + "io" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + createdAt uint64 + masterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^24-1> } Certificate; + certificates [][]byte // Certificate certificate_list<0..2^24-1>; + + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + usedOldKey bool +} + +func (m *sessionState) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(m.vers) + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.masterSecret) + }) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for _, cert := range m.certificates { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + } + }) + return b.BytesOrPanic() +} + +func (m *sessionState) unmarshal(data []byte) bool { + *m = sessionState{usedOldKey: m.usedOldKey} + s := cryptobyte.String(data) + if ok := s.ReadUint16(&m.vers) && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint16LengthPrefixed(&s, &m.masterSecret) && + len(m.masterSecret) != 0; !ok { + return false + } + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + if !readUint24LengthPrefixed(&certList, &cert) { + return false + } + m.certificates = append(m.certificates, cert) + } + return s.Empty() +} + +// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first +// version (revision = 0) doesn't carry any of the information needed for 0-RTT +// validation and the nonce is always empty. +// version (revision = 1) carries the max_early_data_size sent in the ticket. +// version (revision = 2) carries the ALPN sent in the ticket. +type sessionStateTLS13 struct { + // uint8 version = 0x0304; + // uint8 revision = 2; + cipherSuite uint16 + createdAt uint64 + resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; + certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; + maxEarlyData uint32 + alpn string + + appData []byte +} + +func (m *sessionStateTLS13) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(VersionTLS13) + b.AddUint8(2) // revision + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.resumptionSecret) + }) + marshalCertificate(&b, m.certificate) + b.AddUint32(m.maxEarlyData) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpn)) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.appData) + }) + return b.BytesOrPanic() +} + +func (m *sessionStateTLS13) unmarshal(data []byte) bool { + *m = sessionStateTLS13{} + s := cryptobyte.String(data) + var version uint16 + var revision uint8 + var alpn []byte + ret := s.ReadUint16(&version) && + version == VersionTLS13 && + s.ReadUint8(&revision) && + revision == 2 && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint8LengthPrefixed(&s, &m.resumptionSecret) && + len(m.resumptionSecret) != 0 && + unmarshalCertificate(&s, &m.certificate) && + s.ReadUint32(&m.maxEarlyData) && + readUint8LengthPrefixed(&s, &alpn) && + readUint16LengthPrefixed(&s, &m.appData) && + s.Empty() + m.alpn = string(alpn) + return ret +} + +func (c *Conn) encryptTicket(state []byte) ([]byte, error) { + if len(c.ticketKeys) == 0 { + return nil, errors.New("tls: internal error: session ticket keys unavailable") + } + + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + key := c.ticketKeys[0] + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { + if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { + return nil, false + } + + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] + + keyIndex := -1 + for i, candidateKey := range c.ticketKeys { + if bytes.Equal(keyName, candidateKey.keyName[:]) { + keyIndex = i + break + } + } + if keyIndex == -1 { + return nil, false + } + key := &c.ticketKeys[keyIndex] + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, false + } + plaintext = make([]byte, len(ciphertext)) + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + return plaintext, keyIndex > 0 +} + +func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) { + m := new(newSessionTicketMsgTLS13) + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionStateTLS13{ + cipherSuite: c.cipherSuite, + createdAt: uint64(c.config.time().Unix()), + resumptionSecret: c.resumptionSecret, + certificate: Certificate{ + Certificate: certsFromClient, + OCSPStaple: c.ocspResponse, + SignedCertificateTimestamps: c.scts, + }, + appData: appData, + alpn: c.clientProtocol, + } + if c.extraConfig != nil { + state.maxEarlyData = c.extraConfig.MaxEarlyData + } + var err error + m.label, err = c.encryptTicket(state.marshal()) + if err != nil { + return nil, err + } + m.lifetime = uint32(maxSessionTicketLifetime / time.Second) + if c.extraConfig != nil { + m.maxEarlyData = c.extraConfig.MaxEarlyData + } + return m, nil +} + +// GetSessionTicket generates a new session ticket. +// It should only be called after the handshake completes. +// It can only be used for servers, and only if the alternative record layer is set. +// The ticket may be nil if config.SessionTicketsDisabled is set, +// or if the client isn't able to receive session tickets. +func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { + if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.") + } + if c.config.SessionTicketsDisabled { + return nil, nil + } + + m, err := c.getSessionTicketMsg(appData) + if err != nil { + return nil, err + } + return m.marshal(), nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/tls.go b/vendor/github.com/marten-seemann/qtls-go1-16/tls.go new file mode 100644 index 00000000..10cbae03 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/tls.go @@ -0,0 +1,393 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// package qtls partially implements TLS 1.2, as specified in RFC 5246, +// and TLS 1.3, as specified in RFC 8446. +package qtls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "os" + "strings" + "time" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + } + c.handshakeFn = c.serverHandshake + return c +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + isClient: true, + } + c.handshakeFn = c.clientHandshake + return c +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type listener struct { + net.Listener + config *Config + extraConfig *ExtraConfig +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection is of type *Conn. +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return Server(c, l.config, l.extraConfig), nil +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener { + l := new(listener) + l.Listener = inner + l.config = config + l.extraConfig = extraConfig + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) { + if config == nil || len(config.Certificates) == 0 && + config.GetCertificate == nil && config.GetConfigForClient == nil { + return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config, extraConfig), nil +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return dial(context.Background(), dialer, network, addr, config, extraConfig) +} + +func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := netDialer.Timeout + + if !netDialer.Deadline.IsZero() { + deadlineTimeout := time.Until(netDialer.Deadline) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + // hsErrCh is non-nil if we might not wait for Handshake to complete. + var hsErrCh chan error + if timeout != 0 || ctx.Done() != nil { + hsErrCh = make(chan error, 2) + } + if timeout != 0 { + timer := time.AfterFunc(timeout, func() { + hsErrCh <- timeoutError{} + }) + defer timer.Stop() + } + + rawConn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = defaultConfig() + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config, extraConfig) + + if hsErrCh == nil { + err = conn.Handshake() + } else { + go func() { + hsErrCh <- conn.Handshake() + }() + + select { + case <-ctx.Done(): + err = ctx.Err() + case err = <-hsErrCh: + if err != nil { + // If the error was due to the context + // closing, prefer the context's error, rather + // than some random network teardown error. + if e := ctx.Err(); e != nil { + err = e + } + } + } + } + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig) +} + +// Dialer dials TLS connections given a configuration and a Dialer for the +// underlying connection. +type Dialer struct { + // NetDialer is the optional dialer to use for the TLS connections' + // underlying TCP connections. + // A nil NetDialer is equivalent to the net.Dialer zero value. + NetDialer *net.Dialer + + // Config is the TLS configuration to use for new connections. + // A nil configuration is equivalent to the zero + // configuration; see the documentation of Config for the + // defaults. + Config *Config + + ExtraConfig *ExtraConfig +} + +// Dial connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func (d *Dialer) netDialer() *net.Dialer { + if d.NetDialer != nil { + return d.NetDialer + } + return new(net.Dialer) +} + +// DialContext connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig) + if err != nil { + // Don't return c (a typed nil) in an interface. + return nil, err + } + return c, nil +} + +// LoadX509KeyPair reads and parses a public/private key pair from a pair +// of files. The files must contain PEM encoded data. The certificate file +// may contain intermediate certificates following the leaf certificate to +// form a certificate chain. On successful return, Certificate.Leaf will +// be nil because the parsed form of the certificate is not retained. +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := os.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := os.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. On successful return, Certificate.Leaf will be nil because +// the parsed form of the certificate is not retained. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + fail := func(err error) (Certificate, error) { return Certificate{}, err } + + var cert Certificate + var skippedBlockTypes []string + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(cert.Certificate) == 0 { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in certificate input")) + } + if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { + return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) + } + return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + + skippedBlockTypes = skippedBlockTypes[:0] + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in key input")) + } + if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { + return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) + } + return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fail(err) + } + + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return fail(err) + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.N.Cmp(priv.N) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case ed25519.PublicKey: + priv, ok := cert.PrivateKey.(ed25519.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { + return fail(errors.New("tls: private key does not match public key")) + } + default: + return fail(errors.New("tls: unknown public key algorithm")) + } + + return cert, nil +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse private key") +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/unsafe.go b/vendor/github.com/marten-seemann/qtls-go1-16/unsafe.go new file mode 100644 index 00000000..55fa01b3 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-16/unsafe.go @@ -0,0 +1,96 @@ +package qtls + +import ( + "crypto/tls" + "reflect" + "unsafe" +) + +func init() { + if !structsEqual(&tls.ConnectionState{}, &connectionState{}) { + panic("qtls.ConnectionState doesn't match") + } + if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) { + panic("qtls.ClientSessionState doesn't match") + } + if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) { + panic("qtls.CertificateRequestInfo doesn't match") + } + if !structsEqual(&tls.Config{}, &config{}) { + panic("qtls.Config doesn't match") + } + if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { + panic("qtls.ClientHelloInfo doesn't match") + } +} + +func toConnectionState(c connectionState) ConnectionState { + return *(*ConnectionState)(unsafe.Pointer(&c)) +} + +func toClientSessionState(s *clientSessionState) *ClientSessionState { + return (*ClientSessionState)(unsafe.Pointer(s)) +} + +func fromClientSessionState(s *ClientSessionState) *clientSessionState { + return (*clientSessionState)(unsafe.Pointer(s)) +} + +func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo { + return (*CertificateRequestInfo)(unsafe.Pointer(i)) +} + +func toConfig(c *config) *Config { + return (*Config)(unsafe.Pointer(c)) +} + +func fromConfig(c *Config) *config { + return (*config)(unsafe.Pointer(c)) +} + +func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo { + return (*ClientHelloInfo)(unsafe.Pointer(chi)) +} + +func structsEqual(a, b interface{}) bool { + return compare(reflect.ValueOf(a), reflect.ValueOf(b)) +} + +func compare(a, b reflect.Value) bool { + sa := a.Elem() + sb := b.Elem() + if sa.NumField() != sb.NumField() { + return false + } + for i := 0; i < sa.NumField(); i++ { + fa := sa.Type().Field(i) + fb := sb.Type().Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + if fa.Type.Kind() != fb.Type.Kind() { + return false + } + if fa.Type.Kind() == reflect.Slice { + if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) { + return false + } + continue + } + return false + } + } + return true +} + +func compareStruct(a, b reflect.Type) bool { + if a.NumField() != b.NumField() { + return false + } + for i := 0; i < a.NumField(); i++ { + fa := a.Field(i) + fb := b.Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + return false + } + } + return true +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/LICENSE b/vendor/github.com/marten-seemann/qtls-go1-17/LICENSE new file mode 100644 index 00000000..6a66aea5 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/README.md b/vendor/github.com/marten-seemann/qtls-go1-17/README.md new file mode 100644 index 00000000..3e902212 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/README.md @@ -0,0 +1,6 @@ +# qtls + +[![Go Reference](https://pkg.go.dev/badge/github.com/marten-seemann/qtls-go1-17.svg)](https://pkg.go.dev/github.com/marten-seemann/qtls-go1-17) +[![.github/workflows/go-test.yml](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml/badge.svg)](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml) + +This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/alert.go b/vendor/github.com/marten-seemann/qtls-go1-17/alert.go new file mode 100644 index 00000000..3feac79b --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/alert.go @@ -0,0 +1,102 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import "strconv" + +type alert uint8 + +// Alert is a TLS alert +type Alert = alert + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertExportRestriction alert = 60 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertInappropriateFallback alert = 86 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 + alertMissingExtension alert = 109 + alertUnsupportedExtension alert = 110 + alertCertificateUnobtainable alert = 111 + alertUnrecognizedName alert = 112 + alertBadCertificateStatusResponse alert = 113 + alertBadCertificateHashValue alert = 114 + alertUnknownPSKIdentity alert = 115 + alertCertificateRequired alert = 116 + alertNoApplicationProtocol alert = 120 +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertExportRestriction: "export restriction", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertInappropriateFallback: "inappropriate fallback", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", + alertMissingExtension: "missing extension", + alertUnsupportedExtension: "unsupported extension", + alertCertificateUnobtainable: "certificate unobtainable", + alertUnrecognizedName: "unrecognized name", + alertBadCertificateStatusResponse: "bad certificate status response", + alertBadCertificateHashValue: "bad certificate hash value", + alertUnknownPSKIdentity: "unknown PSK identity", + alertCertificateRequired: "certificate required", + alertNoApplicationProtocol: "no application protocol", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return "tls: " + s + } + return "tls: alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/auth.go b/vendor/github.com/marten-seemann/qtls-go1-17/auth.go new file mode 100644 index 00000000..1ef675fd --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/auth.go @@ -0,0 +1,289 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + "hash" + "io" +) + +// verifyHandshakeSignature verifies a signature against pre-hashed +// (if required) handshake contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) + } + if !ecdsa.VerifyASN1(pubKey, signed, sig) { + return errors.New("ECDSA verification failure") + } + case signatureEd25519: + pubKey, ok := pubkey.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) + } + if !ed25519.Verify(pubKey, signed, sig) { + return errors.New("Ed25519 verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { + return err + } + case signatureRSAPSS: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { + return err + } + default: + return errors.New("internal error: unknown signature type") + } + return nil +} + +const ( + serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" + clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" +) + +var signaturePadding = []byte{ + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, +} + +// signedMessage returns the pre-hashed (if necessary) message to be signed by +// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. +func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { + if sigHash == directSigning { + b := &bytes.Buffer{} + b.Write(signaturePadding) + io.WriteString(b, context) + b.Write(transcript.Sum(nil)) + return b.Bytes() + } + h := sigHash.New() + h.Write(signaturePadding) + io.WriteString(h, context) + h.Write(transcript.Sum(nil)) + return h.Sum(nil) +} + +// typeAndHashFromSignatureScheme returns the corresponding signature type and +// crypto.Hash for a given TLS SignatureScheme. +func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { + switch signatureAlgorithm { + case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: + sigType = signaturePKCS1v15 + case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: + sigType = signatureRSAPSS + case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: + sigType = signatureECDSA + case Ed25519: + sigType = signatureEd25519 + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + switch signatureAlgorithm { + case PKCS1WithSHA1, ECDSAWithSHA1: + hash = crypto.SHA1 + case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: + hash = crypto.SHA256 + case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: + hash = crypto.SHA384 + case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: + hash = crypto.SHA512 + case Ed25519: + hash = directSigning + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + return sigType, hash, nil +} + +// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for +// a given public key used with TLS 1.0 and 1.1, before the introduction of +// signature algorithm negotiation. +func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { + switch pub.(type) { + case *rsa.PublicKey: + return signaturePKCS1v15, crypto.MD5SHA1, nil + case *ecdsa.PublicKey: + return signatureECDSA, crypto.SHA1, nil + case ed25519.PublicKey: + // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, + // but it requires holding on to a handshake transcript to do a + // full signature, and not even OpenSSL bothers with the + // complexity, so we can't even test it properly. + return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") + default: + return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) + } +} + +var rsaSignatureSchemes = []struct { + scheme SignatureScheme + minModulusBytes int + maxVersion uint16 +}{ + // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires + // emLen >= hLen + sLen + 2 + {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, + // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires + // emLen >= len(prefix) + hLen + 11 + // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. + {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, + {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, + {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, + {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, +} + +// signatureSchemesForCertificate returns the list of supported SignatureSchemes +// for a given certificate, based on the public key and the protocol version, +// and optionally filtered by its explicit SupportedSignatureAlgorithms. +// +// This function must be kept in sync with supportedSignatureAlgorithms. +func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil + } + + var sigAlgs []SignatureScheme + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + if version != VersionTLS13 { + // In TLS 1.2 and earlier, ECDSA algorithms are not + // constrained to a single curve. + sigAlgs = []SignatureScheme{ + ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + ECDSAWithSHA1, + } + break + } + switch pub.Curve { + case elliptic.P256(): + sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return nil + } + case *rsa.PublicKey: + size := pub.Size() + sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) + for _, candidate := range rsaSignatureSchemes { + if size >= candidate.minModulusBytes && version <= candidate.maxVersion { + sigAlgs = append(sigAlgs, candidate.scheme) + } + } + case ed25519.PublicKey: + sigAlgs = []SignatureScheme{Ed25519} + default: + return nil + } + + if cert.SupportedSignatureAlgorithms != nil { + var filteredSigAlgs []SignatureScheme + for _, sigAlg := range sigAlgs { + if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { + filteredSigAlgs = append(filteredSigAlgs, sigAlg) + } + } + return filteredSigAlgs + } + return sigAlgs +} + +// selectSignatureScheme picks a SignatureScheme from the peer's preference list +// that works with the selected certificate. It's only called for protocol +// versions that support signature algorithms, so TLS 1.2 and 1.3. +func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { + supportedAlgs := signatureSchemesForCertificate(vers, c) + if len(supportedAlgs) == 0 { + return 0, unsupportedCertificateError(c) + } + if len(peerAlgs) == 0 && vers == VersionTLS12 { + // For TLS 1.2, if the client didn't send signature_algorithms then we + // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. + peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} + } + // Pick signature scheme in the peer's preference order, as our + // preference order is not configurable. + for _, preferredAlg := range peerAlgs { + if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { + return preferredAlg, nil + } + } + return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") +} + +// unsupportedCertificateError returns a helpful error for certificates with +// an unsupported private key. +func unsupportedCertificateError(cert *Certificate) error { + switch cert.PrivateKey.(type) { + case rsa.PrivateKey, ecdsa.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", + cert.PrivateKey, cert.PrivateKey) + case *ed25519.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") + } + + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", + cert.PrivateKey) + } + + switch pub := signer.Public().(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + case elliptic.P384(): + case elliptic.P521(): + default: + return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) + } + case *rsa.PublicKey: + return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") + case ed25519.PublicKey: + default: + return fmt.Errorf("tls: unsupported certificate key (%T)", pub) + } + + if cert.SupportedSignatureAlgorithms != nil { + return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") + } + + return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go b/vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go new file mode 100644 index 00000000..53a3956a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go @@ -0,0 +1,691 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "fmt" + "hash" + + "golang.org/x/crypto/chacha20poly1305" +) + +// CipherSuite is a TLS cipher suite. Note that most functions in this package +// accept and expose cipher suite IDs instead of this type. +type CipherSuite struct { + ID uint16 + Name string + + // Supported versions is the list of TLS protocol versions that can + // negotiate this cipher suite. + SupportedVersions []uint16 + + // Insecure is true if the cipher suite has known security issues + // due to its primitives, design, or implementation. + Insecure bool +} + +var ( + supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} + supportedOnlyTLS12 = []uint16{VersionTLS12} + supportedOnlyTLS13 = []uint16{VersionTLS13} +) + +// CipherSuites returns a list of cipher suites currently implemented by this +// package, excluding those with security issues, which are returned by +// InsecureCipherSuites. +// +// The list is sorted by ID. Note that the default cipher suites selected by +// this package might depend on logic that can't be captured by a static list, +// and might not match those returned by this function. +func CipherSuites() []*CipherSuite { + return []*CipherSuite{ + {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + + {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, + {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, + {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, + + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + } +} + +// InsecureCipherSuites returns a list of cipher suites currently implemented by +// this package and which have security issues. +// +// Most applications should not use the cipher suites in this list, and should +// only use those returned by CipherSuites. +func InsecureCipherSuites() []*CipherSuite { + // This list includes RC4, CBC_SHA256, and 3DES cipher suites. See + // cipherSuitesPreferenceOrder for details. + return []*CipherSuite{ + {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + } +} + +// CipherSuiteName returns the standard name for the passed cipher suite ID +// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation +// of the ID value if the cipher suite is not implemented by this package. +func CipherSuiteName(id uint16) string { + for _, c := range CipherSuites() { + if c.ID == id { + return c.Name + } + } + for _, c := range InsecureCipherSuites() { + if c.ID == id { + return c.Name + } + } + return fmt.Sprintf("0x%04X", id) +} + +const ( + // suiteECDHE indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECSign indicates that the cipher suite involves an ECDSA or + // EdDSA signature and therefore may only be selected when the server's + // certificate is ECDSA or EdDSA. If this is not set then the cipher suite + // is RSA based. + suiteECSign + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 +) + +// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange +// mechanism, as well as the cipher+MAC pair or the AEAD. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) interface{} + mac func(key []byte) hash.Hash + aead func(key, fixedNonce []byte) aead +} + +var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil}, +} + +// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which +// is also in supportedIDs and passes the ok filter. +func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { + for _, id := range ids { + candidate := cipherSuiteByID(id) + if candidate == nil || !ok(candidate) { + continue + } + + for _, suppID := range supportedIDs { + if id == suppID { + return candidate + } + } + } + return nil +} + +// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash +// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. +type cipherSuiteTLS13 struct { + id uint16 + keyLen int + aead func(key, fixedNonce []byte) aead + hash crypto.Hash +} + +type CipherSuiteTLS13 struct { + ID uint16 + KeyLen int + Hash crypto.Hash + AEAD func(key, fixedNonce []byte) cipher.AEAD +} + +func (c *CipherSuiteTLS13) IVLen() int { + return aeadNonceLength +} + +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. + {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, + {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, + {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, +} + +// cipherSuitesPreferenceOrder is the order in which we'll select (on the +// server) or advertise (on the client) TLS 1.0–1.2 cipher suites. +// +// Cipher suites are filtered but not reordered based on the application and +// peer's preferences, meaning we'll never select a suite lower in this list if +// any higher one is available. This makes it more defensible to keep weaker +// cipher suites enabled, especially on the server side where we get the last +// word, since there are no known downgrade attacks on cipher suites selection. +// +// The list is sorted by applying the following priority rules, stopping at the +// first (most important) applicable one: +// +// - Anything else comes before RC4 +// +// RC4 has practically exploitable biases. See https://www.rc4nomore.com. +// +// - Anything else comes before CBC_SHA256 +// +// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 +// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. +// +// - Anything else comes before 3DES +// +// 3DES has 64-bit blocks, which makes it fundamentally susceptible to +// birthday attacks. See https://sweet32.info. +// +// - ECDHE comes before anything else +// +// Once we got the broken stuff out of the way, the most important +// property a cipher suite can have is forward secrecy. We don't +// implement FFDHE, so that means ECDHE. +// +// - AEADs come before CBC ciphers +// +// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites +// are fundamentally fragile, and suffered from an endless sequence of +// padding oracle attacks. See https://eprint.iacr.org/2015/1129, +// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and +// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. +// +// - AES comes before ChaCha20 +// +// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster +// than ChaCha20Poly1305. +// +// When AES hardware is not available, AES-128-GCM is one or more of: much +// slower, way more complex, and less safe (because not constant time) +// than ChaCha20Poly1305. +// +// We use this list if we think both peers have AES hardware, and +// cipherSuitesPreferenceOrderNoAES otherwise. +// +// - AES-128 comes before AES-256 +// +// The only potential advantages of AES-256 are better multi-target +// margins, and hypothetical post-quantum properties. Neither apply to +// TLS, and AES-256 is slower due to its four extra rounds (which don't +// contribute to the advantages above). +// +// - ECDSA comes before RSA +// +// The relative order of ECDSA and RSA cipher suites doesn't matter, +// as they depend on the certificate. Pick one to get a stable order. +// +var cipherSuitesPreferenceOrder = []uint16{ + // AEADs w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // CBC w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + + // AEADs w/o ECDHE + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + + // CBC w/o ECDHE + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + + // 3DES + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var cipherSuitesPreferenceOrderNoAES = []uint16{ + // ChaCha20Poly1305 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // AES-GCM w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + + // The rest of cipherSuitesPreferenceOrder. + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +// disabledCipherSuites are not used unless explicitly listed in +// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder. +var disabledCipherSuites = []uint16{ + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var ( + defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) + defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen] +) + +// defaultCipherSuitesTLS13 is also the preference order, since there are no +// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as +// cipherSuitesPreferenceOrder applies. +var defaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, +} + +var defaultCipherSuitesTLS13NoAES = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, +} + +var aesgcmCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, + // TLS 1.3 + TLS_AES_128_GCM_SHA256: true, + TLS_AES_256_GCM_SHA384: true, +} + +var nonAESGCMAEADCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, + // TLS 1.3 + TLS_CHACHA20_POLY1305_SHA256: true, +} + +// aesgcmPreferred returns whether the first known cipher in the preference list +// is an AES-GCM cipher, implying the peer has hardware support for it. +func aesgcmPreferred(ciphers []uint16) bool { + for _, cID := range ciphers { + if c := cipherSuiteByID(cID); c != nil { + return aesgcmCiphers[cID] + } + if c := cipherSuiteTLS13ByID(cID); c != nil { + return aesgcmCiphers[cID] + } + } + return false +} + +func cipherRC4(key, iv []byte, isRead bool) interface{} { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) interface{} { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) interface{} { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a SHA-1 based constant time MAC. +func macSHA1(key []byte) hash.Hash { + return hmac.New(newConstantTimeHash(sha1.New), key) +} + +// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and +// is currently only used in disabled-by-default cipher suites. +func macSHA256(key []byte) hash.Hash { + return hmac.New(sha256.New, key) +} + +type aead interface { + cipher.AEAD + + // explicitNonceLen returns the number of bytes of explicit nonce + // included in each record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +const ( + aeadNonceLength = 12 + noncePrefixLength = 4 +) + +// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type prefixNonceAEAD struct { + // nonce contains the fixed part of the nonce in the first four bytes. + nonce [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } +func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } + +func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.nonce[4:], nonce) + return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) +} + +func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + copy(f.nonce[4:], nonce) + return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) +} + +// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} + +func aeadAESGCM(key, noncePrefix []byte) aead { + if len(noncePrefix) != noncePrefixLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &prefixNonceAEAD{aead: aead} + copy(ret.nonce[:], noncePrefix) + return ret +} + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return aeadAESGCMTLS13(key, fixedNonce) +} + +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. +func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { + h.Reset() + h.Write(seq) + h.Write(header) + h.Write(data) + res := h.Sum(out) + if extra != nil { + h.Write(extra) + } + return res +} + +func rsaKA(version uint16) keyAgreement { + return rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: false, + version: version, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: true, + version: version, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + return cipherSuiteByID(id) + } + } + return nil +} + +func cipherSuiteByID(id uint16) *cipherSuite { + for _, cipherSuite := range cipherSuites { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { + for _, id := range have { + if id == want { + return cipherSuiteTLS13ByID(id) + } + } + return nil +} + +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { + for _, cipherSuite := range cipherSuitesTLS13 { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +// A list of cipher suite IDs that are, or have been, implemented by this +// package. +// +// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 + + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See RFC 7507. + TLS_FALLBACK_SCSV uint16 = 0x5600 + + // Legacy names for the corresponding cipher suites with the correct _SHA256 + // suffix, retained for backward compatibility. + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/common.go b/vendor/github.com/marten-seemann/qtls-go1-17/common.go new file mode 100644 index 00000000..8a9b6804 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/common.go @@ -0,0 +1,1485 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "container/list" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "strings" + "sync" + "time" +) + +const ( + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 + + // Deprecated: SSLv3 is cryptographically broken, and is no longer + // supported by this package. See golang.org/issue/32716. + VersionSSL30 = 0x0300 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxUselessRecords = 16 // maximum number of consecutive non-advancing records +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeMessageHash uint8 = 254 // synthetic message +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +type Extension struct { + Type uint16 + Data []byte +} + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionRenegotiationInfo uint16 = 0xff01 +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +type EncryptionLevel uint8 + +const ( + EncryptionHandshake EncryptionLevel = iota + Encryption0RTT + EncryptionApplication +) + +// CurveID is a tls.CurveID +type CurveID = tls.CurveID + +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 +) + +// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. +const ( + pskModePlain uint8 = 0 + pskModeDHE uint8 = 1 +) + +// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved +// session. See RFC 8446, Section 4.2.11. +type pskIdentity struct { + label []byte + obfuscatedTicketAge uint32 +} + +// TLS Elliptic Curve Point Formats +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 + certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. +) + +// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with +// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. +const ( + signaturePKCS1v15 uint8 = iota + 225 + signatureRSAPSS + signatureECDSA + signatureEd25519 +) + +// directSigning is a standard Hash value that signals that no pre-hashing +// should be performed, and that the input should be signed directly. It is the +// hash function associated with the Ed25519 signature scheme. +var directSigning crypto.Hash = 0 + +// supportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ +// CertificateRequest. The two fields are merged to match with TLS 1.3. +// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. +var supportedSignatureAlgorithms = []SignatureScheme{ + PSSWithSHA256, + ECDSAWithP256AndSHA256, + Ed25519, + PSSWithSHA384, + PSSWithSHA512, + PKCS1WithSHA256, + PKCS1WithSHA384, + PKCS1WithSHA512, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// helloRetryRequestRandom is set as the Random value of a ServerHello +// to signal that the message is actually a HelloRetryRequest. +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +const ( + // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server + // random as a downgrade protection if the server would be capable of + // negotiating a higher version. See RFC 8446, Section 4.1.3. + downgradeCanaryTLS12 = "DOWNGRD\x01" + downgradeCanaryTLS11 = "DOWNGRD\x00" +) + +// testingOnlyForceDowngradeCanary is set in tests to force the server side to +// include downgrade canaries even if it's using its highers supported version. +var testingOnlyForceDowngradeCanary bool + +type ConnectionState = tls.ConnectionState + +// ConnectionState records basic TLS details about the connection. +type connectionState struct { + // Version is the TLS version used by the connection (e.g. VersionTLS12). + Version uint16 + + // HandshakeComplete is true if the handshake has concluded. + HandshakeComplete bool + + // DidResume is true if this connection was successfully resumed from a + // previous session with a session ticket or similar mechanism. + DidResume bool + + // CipherSuite is the cipher suite negotiated for the connection (e.g. + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). + CipherSuite uint16 + + // NegotiatedProtocol is the application protocol negotiated with ALPN. + NegotiatedProtocol string + + // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. + // + // Deprecated: this value is always true. + NegotiatedProtocolIsMutual bool + + // ServerName is the value of the Server Name Indication extension sent by + // the client. It's available both on the server and on the client side. + ServerName string + + // PeerCertificates are the parsed certificates sent by the peer, in the + // order in which they were sent. The first element is the leaf certificate + // that the connection is verified against. + // + // On the client side, it can't be empty. On the server side, it can be + // empty if Config.ClientAuth is not RequireAnyClientCert or + // RequireAndVerifyClientCert. + PeerCertificates []*x509.Certificate + + // VerifiedChains is a list of one or more chains where the first element is + // PeerCertificates[0] and the last element is from Config.RootCAs (on the + // client side) or Config.ClientCAs (on the server side). + // + // On the client side, it's set if Config.InsecureSkipVerify is false. On + // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven + // (and the peer provided a certificate) or RequireAndVerifyClientCert. + VerifiedChains [][]*x509.Certificate + + // SignedCertificateTimestamps is a list of SCTs provided by the peer + // through the TLS handshake for the leaf certificate, if any. + SignedCertificateTimestamps [][]byte + + // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) + // response provided by the peer for the leaf certificate, if any. + OCSPResponse []byte + + // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, + // Section 3). This value will be nil for TLS 1.3 connections and for all + // resumed connections. + // + // Deprecated: there are conditions in which this value might not be unique + // to a connection. See the Security Considerations sections of RFC 5705 and + // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. + TLSUnique []byte + + // ekm is a closure exposed via ExportKeyingMaterial. + ekm func(label string, context []byte, length int) ([]byte, error) +} + +type ConnectionStateWith0RTT struct { + ConnectionState + + Used0RTT bool // true if 0-RTT was both offered and accepted +} + +// ClientAuthType is tls.ClientAuthType +type ClientAuthType = tls.ClientAuthType + +const ( + NoClientCert = tls.NoClientCert + RequestClientCert = tls.RequestClientCert + RequireAnyClientCert = tls.RequireAnyClientCert + VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven + RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert +) + +// requiresClientCert reports whether the ClientAuthType requires a client +// certificate to be provided. +func requiresClientCert(c ClientAuthType) bool { + switch c { + case RequireAnyClientCert, RequireAndVerifyClientCert: + return true + default: + return false + } +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState = tls.ClientSessionState + +type clientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + receivedAt time.Time // When the session ticket was received from the server + ocspResponse []byte // Stapled OCSP response presented by the server + scts [][]byte // SCTs presented by the server + + // TLS 1.3 fields. + nonce []byte // Ticket nonce sent by the server, to derive PSK + useBy time.Time // Expiration of the ticket lifetime as set by the server + ageAdd uint32 // Random obfuscation factor for sending the ticket age +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not +// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which +// are supported via this interface. +//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-17 ClientSessionCache" +type ClientSessionCache = tls.ClientSessionCache + +// SignatureScheme is a tls.SignatureScheme +type SignatureScheme = tls.SignatureScheme + +const ( + // RSASSA-PKCS1-v1_5 algorithms. + PKCS1WithSHA256 SignatureScheme = 0x0401 + PKCS1WithSHA384 SignatureScheme = 0x0501 + PKCS1WithSHA512 SignatureScheme = 0x0601 + + // RSASSA-PSS algorithms with public key OID rsaEncryption. + PSSWithSHA256 SignatureScheme = 0x0804 + PSSWithSHA384 SignatureScheme = 0x0805 + PSSWithSHA512 SignatureScheme = 0x0806 + + // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. + ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 + ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 + ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 + + // EdDSA algorithms. + Ed25519 SignatureScheme = 0x0807 + + // Legacy signature and hash algorithms for TLS 1.2. + PKCS1WithSHA1 SignatureScheme = 0x0201 + ECDSAWithSHA1 SignatureScheme = 0x0203 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide application logic in the GetCertificate and GetConfigForClient callbacks. +type ClientHelloInfo = tls.ClientHelloInfo + +type clientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see RFC 4366, Section 3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see RFC 4492, Section 5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see RFC 4492, Section 5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see RFC 7301, Section 3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn net.Conn + + // config is embedded by the GetCertificate or GetConfigForClient caller, + // for use with SupportsCertificate. + config *Config + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *clientHelloInfo) Context() context.Context { + return c.ctx +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo = tls.CertificateRequestInfo + +type certificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme + + // Version is the TLS version that was negotiated for this connection. + Version uint16 + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *certificateRequestInfo) Context() context.Context { + return c.ctx +} + +// RenegotiationSupport enumerates the different levels of support for TLS +// renegotiation. TLS renegotiation is the act of performing subsequent +// handshakes on a connection after the first. This significantly complicates +// the state machine and has been the source of numerous, subtle security +// issues. Initiating a renegotiation is not supported, but support for +// accepting renegotiation requests may be enabled. +// +// Even when enabled, the server may not change its identity between handshakes +// (i.e. the leaf certificate must be the same). Additionally, concurrent +// handshake and application data flow is not permitted so renegotiation can +// only be used with protocols that synchronise with the renegotiation, such as +// HTTPS. +// +// Renegotiation is not defined in TLS 1.3. +type RenegotiationSupport = tls.RenegotiationSupport + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever = tls.RenegotiateNever + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config = tls.Config + +type config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to the + // other side of the connection. The first certificate compatible with the + // peer's requirements is selected automatically. + // + // Server configurations must set one of Certificates, GetCertificate or + // GetConfigForClient. Clients doing client-authentication may set either + // Certificates or GetClientCertificate. + // + // Note: if there are multiple Certificates, and they don't have the + // optional field Leaf set, certificate selection will incur a significant + // per-handshake performance cost. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // + // Deprecated: NameToCertificate only allows associating a single + // certificate with a given name. Leave this field nil to let the library + // select the first compatible chain from Certificates. + NameToCertificate map[string]*Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // best element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // If SessionTicketKey was explicitly set on the returned Config, or if + // SetSessionTicketKeys was called on the returned Config, those keys will + // be used. Otherwise, the original Config keys will be used (and possibly + // rotated if they are automatically managed). + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported application level protocols, in + // order of preference. If both peers support ALPN, the selected + // protocol will be one from this list, and the connection will fail + // if there is no mutually supported protocol. If NextProtos is empty + // or the peer doesn't support ALPN, the connection will succeed and + // ConnectionState.NegotiatedProtocol will be empty. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the server's + // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls + // accepts any certificate presented by the server and any host name in that + // certificate. In this mode, TLS is susceptible to machine-in-the-middle + // attacks unless custom verification is used. This should be used only for + // testing or in combination with VerifyConnection or VerifyPeerCertificate. + InsecureSkipVerify bool + + // CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of + // the list is ignored. Note that TLS 1.3 ciphersuites are not configurable. + // + // If CipherSuites is nil, a safe default list is used. The default cipher + // suites might change over time. + CipherSuites []uint16 + + // PreferServerCipherSuites is a legacy field and has no effect. + // + // It used to control whether the server would follow the client's or the + // server's preference. Servers now select the best mutually supported + // cipher suite based on logic that takes into account inferred client + // hardware, server hardware, and security. + // + // Deprected: PreferServerCipherSuites is ignored. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket and + // PSK (resumption) support. Note that on clients, session ticket support is + // also disabled if ClientSessionCache is nil. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session resumption. + // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled + // with random data before the first server handshake. + // + // Deprecated: if this field is left at zero, session ticket keys will be + // automatically rotated every day and dropped after seven days. For + // customizing the rotation schedule or synchronizing servers that are + // terminating connections for the same host, use SetSessionTicketKeys. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries for TLS + // session resumption. It is only used by clients. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum TLS version that is acceptable. + // If zero, TLS 1.0 is currently taken as the minimum. + MinVersion uint16 + + // MaxVersion contains the maximum TLS version that is acceptable. + // If zero, the maximum version supported by this package is used, + // which is currently TLS 1.3. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. The client will use the first preference as the type for + // its key share in TLS 1.3. This may change in the future. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // mutex protects sessionTicketKeys and autoSessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // the keys were set with SessionTicketKey or SetSessionTicketKeys. The + // first key is used for new tickets and any subsequent keys can be used to + // decrypt old tickets. The slice contents are not protected by the mutex + // and are immutable. + sessionTicketKeys []ticketKey + // autoSessionTicketKeys is like sessionTicketKeys but is owned by the + // auto-rotation logic. See Config.ticketKeys. + autoSessionTicketKeys []ticketKey +} + +// A RecordLayer handles encrypting and decrypting of TLS messages. +type RecordLayer interface { + SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + ReadHandshakeMessage() ([]byte, error) + WriteRecord([]byte) (int, error) + SendAlert(uint8) +} + +type ExtraConfig struct { + // GetExtensions, if not nil, is called before a message that allows + // sending of extensions is sent. + // Currently only implemented for the ClientHello message (for the client) + // and for the EncryptedExtensions message (for the server). + // Only valid for TLS 1.3. + GetExtensions func(handshakeMessageType uint8) []Extension + + // ReceivedExtensions, if not nil, is called when a message that allows the + // inclusion of extensions is received. + // It is called with an empty slice of extensions, if the message didn't + // contain any extensions. + // Currently only implemented for the ClientHello message (sent by the + // client) and for the EncryptedExtensions message (sent by the server). + // Only valid for TLS 1.3. + ReceivedExtensions func(handshakeMessageType uint8, exts []Extension) + + // AlternativeRecordLayer is used by QUIC + AlternativeRecordLayer RecordLayer + + // Enforce the selection of a supported application protocol. + // Only works for TLS 1.3. + // If enabled, client and server have to agree on an application protocol. + // Otherwise, connection establishment fails. + EnforceNextProtoSelection bool + + // If MaxEarlyData is greater than 0, the client will be allowed to send early + // data when resuming a session. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning on the client. + MaxEarlyData uint32 + + // The Accept0RTT callback is called when the client offers 0-RTT. + // The server then has to decide if it wants to accept or reject 0-RTT. + // It is only used for servers. + Accept0RTT func(appData []byte) bool + + // 0RTTRejected is called when the server rejectes 0-RTT. + // It is only used for clients. + Rejected0RTT func() + + // If set, the client will export the 0-RTT key when resuming a session that + // allows sending of early data. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning to the server. + Enable0RTT bool + + // Is called when the client saves a session ticket to the session ticket. + // This gives the application the opportunity to save some data along with the ticket, + // which can be restored when the session ticket is used. + GetAppDataForSessionState func() []byte + + // Is called when the client uses a session ticket. + // Restores the application data that was saved earlier on GetAppDataForSessionTicket. + SetAppDataFromSessionState func([]byte) +} + +// Clone clones. +func (c *ExtraConfig) Clone() *ExtraConfig { + return &ExtraConfig{ + GetExtensions: c.GetExtensions, + ReceivedExtensions: c.ReceivedExtensions, + AlternativeRecordLayer: c.AlternativeRecordLayer, + EnforceNextProtoSelection: c.EnforceNextProtoSelection, + MaxEarlyData: c.MaxEarlyData, + Enable0RTT: c.Enable0RTT, + Accept0RTT: c.Accept0RTT, + Rejected0RTT: c.Rejected0RTT, + GetAppDataForSessionState: c.GetAppDataForSessionState, + SetAppDataFromSessionState: c.SetAppDataFromSessionState, + } +} + +func (c *ExtraConfig) usesAlternativeRecordLayer() bool { + return c != nil && c.AlternativeRecordLayer != nil +} + +const ( + // ticketKeyNameLen is the number of bytes of identifier that is prepended to + // an encrypted session ticket in order to identify the key used to encrypt it. + ticketKeyNameLen = 16 + + // ticketKeyLifetime is how long a ticket key remains valid and can be used to + // resume a client connection. + ticketKeyLifetime = 7 * 24 * time.Hour // 7 days + + // ticketKeyRotation is how often the server should rotate the session ticket key + // that is used for new tickets. + ticketKeyRotation = 24 * time.Hour +) + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte + // created is the time at which this ticket key was created. See Config.ticketKeys. + created time.Time +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func (c *config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + key.created = c.time() + return key +} + +// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session +// ticket, and the lifetime we set for tickets we send. +const maxSessionTicketLifetime = 7 * 24 * time.Hour + +// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *config) Clone() *config { + if c == nil { + return nil + } + c.mutex.RLock() + defer c.mutex.RUnlock() + return &config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: c.sessionTicketKeys, + autoSessionTicketKeys: c.autoSessionTicketKeys, + } +} + +// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was +// randomized for backwards compatibility but is not in use. +var deprecatedSessionTicketKey = []byte("DEPRECATED") + +// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is +// randomized if empty, and that sessionTicketKeys is populated from it otherwise. +func (c *config) initLegacySessionTicketKeyRLocked() { + // Don't write if SessionTicketKey is already defined as our deprecated string, + // or if it is defined by the user but sessionTicketKeys is already set. + if c.SessionTicketKey != [32]byte{} && + (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { + return + } + + // We need to write some data, so get an exclusive lock and re-check any conditions. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + if c.SessionTicketKey == [32]byte{} { + if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) + } + // Write the deprecated prefix at the beginning so we know we created + // it. This key with the DEPRECATED prefix isn't used as an actual + // session ticket key, and is only randomized in case the application + // reuses it for some reason. + copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) + } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { + c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} + } + +} + +// ticketKeys returns the ticketKeys for this connection. +// If configForClient has explicitly set keys, those will +// be returned. Otherwise, the keys on c will be used and +// may be rotated if auto-managed. +// During rotation, any expired session ticket keys are deleted from +// c.sessionTicketKeys. If the session ticket key that is currently +// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) +// is not fresh, then a new session ticket key will be +// created and prepended to c.sessionTicketKeys. +func (c *config) ticketKeys(configForClient *config) []ticketKey { + // If the ConfigForClient callback returned a Config with explicitly set + // keys, use those, otherwise just use the original Config. + if configForClient != nil { + configForClient.mutex.RLock() + if configForClient.SessionTicketsDisabled { + return nil + } + configForClient.initLegacySessionTicketKeyRLocked() + if len(configForClient.sessionTicketKeys) != 0 { + ret := configForClient.sessionTicketKeys + configForClient.mutex.RUnlock() + return ret + } + configForClient.mutex.RUnlock() + } + + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.SessionTicketsDisabled { + return nil + } + c.initLegacySessionTicketKeyRLocked() + if len(c.sessionTicketKeys) != 0 { + return c.sessionTicketKeys + } + // Fast path for the common case where the key is fresh enough. + if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { + return c.autoSessionTicketKeys + } + + // autoSessionTicketKeys are managed by auto-rotation. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + // Re-check the condition in case it changed since obtaining the new lock. + if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { + var newKey [32]byte + if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { + panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) + } + valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) + valid = append(valid, c.ticketKeyFromBytes(newKey)) + for _, k := range c.autoSessionTicketKeys { + // While rotating the current key, also remove any expired ones. + if c.time().Sub(k.created) < ticketKeyLifetime { + valid = append(valid, k) + } + } + c.autoSessionTicketKeys = valid + } + return c.autoSessionTicketKeys +} + +// SetSessionTicketKeys updates the session ticket keys for a server. +// +// The first key will be used when creating new tickets, while all keys can be +// used for decrypting tickets. It is safe to call this function while the +// server is running in order to rotate the session ticket keys. The function +// will panic if keys is empty. +// +// Calling this function will turn off automatic session ticket key rotation. +// +// If multiple servers are terminating connections for the same host they should +// all have the same session ticket keys. If the session ticket keys leaks, +// previously recorded and future TLS connections using those keys might be +// compromised. +func (c *config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = c.ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *config) cipherSuites() []uint16 { + if c.CipherSuites != nil { + return c.CipherSuites + } + return defaultCipherSuites +} + +var supportedVersions = []uint16{ + VersionTLS13, + VersionTLS12, + VersionTLS11, + VersionTLS10, +} + +func (c *config) supportedVersions() []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if c != nil && c.MinVersion != 0 && v < c.MinVersion { + continue + } + if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +func (c *config) maxSupportedVersion() uint16 { + supportedVersions := c.supportedVersions() + if len(supportedVersions) == 0 { + return 0 + } + return supportedVersions[0] +} + +// supportedVersionsFromMax returns a list of supported versions derived from a +// legacy maximum version value. Note that only versions supported by this +// library are returned. Any newer peer will use supportedVersions anyway. +func supportedVersionsFromMax(maxVersion uint16) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if v > maxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +func (c *config) supportsCurve(curve CurveID) bool { + for _, cc := range c.curvePreferences() { + if cc == curve { + return true + } + } + return false +} + +// mutualVersion returns the protocol version to use given the advertised +// versions of the peer. Priority is given to the peer preference order. +func (c *config) mutualVersion(peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions() + for _, peerVersion := range peerVersions { + for _, v := range supportedVersions { + if v == peerVersion { + return v, true + } + } + } + return 0, false +} + +var errNoCertificates = errors.New("tls: no certificates configured") + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errNoCertificates + } + + if len(c.Certificates) == 1 { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + if c.NameToCertificate != nil { + name := strings.ToLower(clientHello.ServerName) + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + if len(name) > 0 { + labels := strings.Split(name, ".") + labels[0] = "*" + wildcardName := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[wildcardName]; ok { + return cert, nil + } + } + } + + for _, cert := range c.Certificates { + if err := clientHello.SupportsCertificate(&cert); err == nil { + return &cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the client that sent the ClientHello. Otherwise, it returns an error +// describing the reason for the incompatibility. +// +// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate +// callback, this method will take into account the associated Config. Note that +// if GetConfigForClient returns a different Config, the change can't be +// accounted for by this method. +// +// This function will call x509.ParseCertificate unless c.Leaf is set, which can +// incur a significant performance cost. +func (chi *clientHelloInfo) SupportsCertificate(c *Certificate) error { + // Note we don't currently support certificate_authorities nor + // signature_algorithms_cert, and don't check the algorithms of the + // signatures on the chain (which anyway are a SHOULD, see RFC 8446, + // Section 4.4.2.2). + + config := chi.config + if config == nil { + config = &Config{} + } + conf := fromConfig(config) + vers, ok := conf.mutualVersion(chi.SupportedVersions) + if !ok { + return errors.New("no mutually supported protocol versions") + } + + // If the client specified the name they are trying to connect to, the + // certificate needs to be valid for it. + if chi.ServerName != "" { + x509Cert, err := leafCertificate(c) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { + return fmt.Errorf("certificate is not valid for requested server name: %w", err) + } + } + + // supportsRSAFallback returns nil if the certificate and connection support + // the static RSA key exchange, and unsupported otherwise. The logic for + // supporting static RSA is completely disjoint from the logic for + // supporting signed key exchanges, so we just check it as a fallback. + supportsRSAFallback := func(unsupported error) error { + // TLS 1.3 dropped support for the static RSA key exchange. + if vers == VersionTLS13 { + return unsupported + } + // The static RSA key exchange works by decrypting a challenge with the + // RSA private key, not by signing, so check the PrivateKey implements + // crypto.Decrypter, like *rsa.PrivateKey does. + if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { + if _, ok := priv.Public().(*rsa.PublicKey); !ok { + return unsupported + } + } else { + return unsupported + } + // Finally, there needs to be a mutual cipher suite that uses the static + // RSA key exchange instead of ECDHE. + rsaCipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + return false + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if rsaCipherSuite == nil { + return unsupported + } + return nil + } + + // If the client sent the signature_algorithms extension, ensure it supports + // schemes we can use with this certificate and TLS version. + if len(chi.SignatureSchemes) > 0 { + if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { + return supportsRSAFallback(err) + } + } + + // In TLS 1.3 we are done because supported_groups is only relevant to the + // ECDHE computation, point format negotiation is removed, cipher suites are + // only relevant to the AEAD choice, and static RSA does not exist. + if vers == VersionTLS13 { + return nil + } + + // The only signed key exchange we support is ECDHE. + if !supportsECDHE(conf, chi.SupportedCurves, chi.SupportedPoints) { + return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) + } + + var ecdsaCipherSuite bool + if priv, ok := c.PrivateKey.(crypto.Signer); ok { + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + var curve CurveID + switch pub.Curve { + case elliptic.P256(): + curve = CurveP256 + case elliptic.P384(): + curve = CurveP384 + case elliptic.P521(): + curve = CurveP521 + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + var curveOk bool + for _, c := range chi.SupportedCurves { + if c == curve && conf.supportsCurve(c) { + curveOk = true + break + } + } + if !curveOk { + return errors.New("client doesn't support certificate curve") + } + ecdsaCipherSuite = true + case ed25519.PublicKey: + if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { + return errors.New("connection doesn't support Ed25519") + } + ecdsaCipherSuite = true + case *rsa.PublicKey: + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + } else { + return supportsRSAFallback(unsupportedCertificateError(c)) + } + + // Make sure that there is a mutually supported cipher suite that works with + // this certificate. Cipher suite selection will then apply the logic in + // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. + cipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE == 0 { + return false + } + if c.flags&suiteECSign != 0 { + if !ecdsaCipherSuite { + return false + } + } else { + if ecdsaCipherSuite { + return false + } + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if cipherSuite == nil { + return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) + } + + return nil +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +// +// Deprecated: NameToCertificate only allows associating a single certificate +// with a given name. Leave that field nil to let the library select the first +// compatible chain from Certificates. +func (c *config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := leafCertificate(cert) + if err != nil { + continue + } + // If SANs are *not* present, some clients will consider the certificate + // valid for the name in the Common Name. + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +const ( + keyLogLabelTLS12 = "CLIENT_RANDOM" + keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET" + keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" + keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" +) + +func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate = tls.Certificate + +// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing +// the corresponding c.Certificate[0]. +func leafCertificate(c *Certificate) (*x509.Certificate, error) { + if c.Leaf != nil { + return c.Leaf, nil + } + return x509.ParseCertificate(c.Certificate[0]) +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry +// corresponding to sessionKey is removed from the cache instead. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + if cs == nil { + c.q.Remove(elem) + delete(c.m, sessionKey) + } else { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + } + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { + for _, s := range supportedSignatureAlgorithms { + if s == sigAlg { + return true + } + } + return false +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/conn.go b/vendor/github.com/marten-seemann/qtls-go1-17/conn.go new file mode 100644 index 00000000..70fad465 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/conn.go @@ -0,0 +1,1601 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package qtls + +import ( + "bytes" + "context" + "crypto/cipher" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + isClient bool + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake + + // handshakeStatus is 1 if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // handshakeStatus == 1 implies handshakeErr == nil. + // This field is only to be accessed with sync/atomic. + handshakeStatus uint32 + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *config // configuration passed to constructor + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + extraConfig *ExtraConfig + + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + // ekm is a closure for exporting keying material. + ekm func(label string, context []byte, length int) ([]byte, error) + // For the client: + // resumptionSecret is the resumption_master_secret for handling + // NewSessionTicket messages. nil if config.SessionTicketsDisabled. + // For the server: + // resumptionSecret is the resumption_master_secret for generating + // NewSessionTicket messages. Only used when the alternative record + // layer is set. nil if config.SessionTicketsDisabled. + resumptionSecret []byte + + // ticketKeys is the set of active session ticket keys for this + // connection. The first one is used to encrypt new tickets and + // all are tried to decrypt tickets. + ticketKeys []ticketKey + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string + + // input/output + in, out halfConn + rawInput bytes.Buffer // raw input, starting with a record header + input bytes.Reader // application data waiting to be read, from rawInput.Next + hand bytes.Buffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf []byte // a buffer of records waiting to be sent + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 + + // retryCount counts the number of consecutive non-advancing records + // received by Conn.readRecord. That is, records that neither advance the + // handshake, nor deliver application data. Protected by in.Mutex. + retryCount int + + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + activeCall int32 + + used0RTT bool + + tmp [16]byte +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher interface{} // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape + + nextCipher interface{} // next encryption state + nextMac hash.Hash // next MAC algorithm + + trafficSecret []byte // current TLS 1.3 traffic secret + + setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) +} + +type permanentError struct { + err net.Error +} + +func (e *permanentError) Error() string { return e.err.Error() } +func (e *permanentError) Unwrap() error { return e.err } +func (e *permanentError) Timeout() bool { return e.err.Timeout() } +func (e *permanentError) Temporary() bool { return false } + +func (hc *halfConn) setErrorLocked(err error) error { + if e, ok := err.(net.Error); ok { + hc.err = &permanentError{err: e} + } else { + hc.err = err + } + return hc.err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() error { + if hc.nextCipher == nil || hc.version == VersionTLS13 { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + for i := range hc.seq { + hc.seq[i] = 0 + } + return nil +} + +func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) { + if hc.setKeyCallback != nil { + s := &CipherSuiteTLS13{ + ID: suite.id, + KeyLen: suite.keyLen, + Hash: suite.hash, + AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) }, + } + hc.setKeyCallback(encLevel, s, trafficSecret) + } +} + +func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { + hc.trafficSecret = secret + key, iv := suite.trafficKey(secret) + hc.cipher = suite.aead(key, iv) + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// explicitNonceLen returns the number of bytes of explicit nonce or IV included +// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 +// and in certain AEAD modes in TLS 1.2. +func (hc *halfConn) explicitNonceLen() int { + if hc.cipher == nil { + return 0 + } + + switch c := hc.cipher.(type) { + case cipher.Stream: + return 0 + case aead: + return c.explicitNonceLen() + case cbcMode: + // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. + if hc.version >= VersionTLS11 { + return c.BlockSize() + } + return 0 + default: + panic("unknown cipher type") + } +} + +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +func extractPadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + // Zero the padding length on error. This ensures any unchecked bytes + // are included in the MAC. Otherwise, an attacker that could + // distinguish MAC failures from padding failures could mount an attack + // similar to POODLE in SSL 3.0: given a good ciphertext that uses a + // full block's worth of padding, replace the final block with another + // block. If the MAC check passed but the padding check failed, the + // last byte of that block decrypted to the block size. + // + // See also macAndPaddingGood logic below. + paddingLen &= good + + toRemove = int(paddingLen) + 1 + return +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt authenticates and decrypts the record if protection is active at +// this stage. The returned plaintext might overlap with the input. +func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { + var plaintext []byte + typ := recordType(record[0]) + payload := record[recordHeaderLen:] + + // In TLS 1.3, change_cipher_spec messages are to be ignored without being + // decrypted. See RFC 8446, Appendix D.4. + if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { + return payload, typ, nil + } + + paddingGood := byte(255) + paddingLen := 0 + + explicitNonceLen := hc.explicitNonceLen() + + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + if len(payload) < explicitNonceLen { + return nil, 0, alertBadRecordMAC + } + nonce := payload[:explicitNonceLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } + payload = payload[explicitNonceLen:] + + var additionalData []byte + if hc.version == VersionTLS13 { + additionalData = record[:recordHeaderLen] + } else { + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) + n := len(payload) - c.Overhead() + additionalData = append(additionalData, byte(n>>8), byte(n)) + } + + var err error + plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) + if err != nil { + return nil, 0, alertBadRecordMAC + } + case cbcMode: + blockSize := c.BlockSize() + minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) + if len(payload)%blockSize != 0 || len(payload) < minPayload { + return nil, 0, alertBadRecordMAC + } + + if explicitNonceLen > 0 { + c.SetIV(payload[:explicitNonceLen]) + payload = payload[explicitNonceLen:] + } + c.CryptBlocks(payload, payload) + + // In a limited attempt to protect against CBC padding oracles like + // Lucky13, the data past paddingLen (which is secret) is passed to + // the MAC function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC roughly constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write, modulo cache effects. + paddingLen, paddingGood = extractPadding(payload) + default: + panic("unknown cipher type") + } + + if hc.version == VersionTLS13 { + if typ != recordTypeApplicationData { + return nil, 0, alertUnexpectedMessage + } + if len(plaintext) > maxPlaintext+1 { + return nil, 0, alertRecordOverflow + } + // Remove padding and find the ContentType scanning from the end. + for i := len(plaintext) - 1; i >= 0; i-- { + if plaintext[i] != 0 { + typ = recordType(plaintext[i]) + plaintext = plaintext[:i] + break + } + if i == 0 { + return nil, 0, alertUnexpectedMessage + } + } + } + } else { + plaintext = payload + } + + if hc.mac != nil { + macSize := hc.mac.Size() + if len(payload) < macSize { + return nil, 0, alertBadRecordMAC + } + + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } + record[3] = byte(n >> 8) + record[4] = byte(n) + remoteMAC := payload[n : n+macSize] + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + + // This is equivalent to checking the MACs and paddingGood + // separately, but in constant-time to prevent distinguishing + // padding failures from MAC failures. Depending on what value + // of paddingLen was returned on bad padding, distinguishing + // bad MAC from bad padding can lead to an attack. + // + // See also the logic at the end of extractPadding. + macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) + if macAndPaddingGood != 1 { + return nil, 0, alertBadRecordMAC + } + + plaintext = payload[:n] + } + + hc.incSeq() + return plaintext, typ, nil +} + +func (c *Conn) setAlternativeRecordLayer() { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey + c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey + } +} + +// sliceForAppend extends the input slice by n bytes. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and +// appends it to record, which must already contain the record header. +func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { + if hc.cipher == nil { + return append(record, payload...), nil + } + + var explicitNonce []byte + if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { + record, explicitNonce = sliceForAppend(record, explicitNonceLen) + if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { + // The AES-GCM construction in TLS has an explicit nonce so that the + // nonce can be random. However, the nonce is only 8 bytes which is + // too small for a secure, random nonce. Therefore we use the + // sequence number as the nonce. The 3DES-CBC construction also has + // an 8 bytes nonce but its nonces must be unpredictable (see RFC + // 5246, Appendix F.3), forcing us to use randomness. That's not + // 3DES' biggest problem anyway because the birthday bound on block + // collision is reached first due to its similarly small block size + // (see the Sweet32 attack). + copy(explicitNonce, hc.seq[:]) + } else { + if _, err := io.ReadFull(rand, explicitNonce); err != nil { + return nil, err + } + } + } + + var dst []byte + switch c := hc.cipher.(type) { + case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + record, dst = sliceForAppend(record, len(payload)+len(mac)) + c.XORKeyStream(dst[:len(payload)], payload) + c.XORKeyStream(dst[len(payload):], mac) + case aead: + nonce := explicitNonce + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + if hc.version == VersionTLS13 { + record = append(record, payload...) + + // Encrypt the actual ContentType and replace the plaintext one. + record = append(record, record[0]) + record[0] = byte(recordTypeApplicationData) + + n := len(payload) + 1 + c.Overhead() + record[3] = byte(n >> 8) + record[4] = byte(n) + + record = c.Seal(record[:recordHeaderLen], + nonce, record[recordHeaderLen:], record[:recordHeaderLen]) + } else { + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) + } + case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + blockSize := c.BlockSize() + plaintextLen := len(payload) + len(mac) + paddingLen := blockSize - plaintextLen%blockSize + record, dst = sliceForAppend(record, plaintextLen+paddingLen) + copy(dst, payload) + copy(dst[len(payload):], mac) + for i := plaintextLen; i < len(dst); i++ { + dst[i] = byte(paddingLen - 1) + } + if len(explicitNonce) > 0 { + c.SetIV(explicitNonce) + } + c.CryptBlocks(dst, dst) + default: + panic("unknown cipher type") + } + + // Update length to include nonce, MAC and any block padding needed. + n := len(record) - recordHeaderLen + record[3] = byte(n >> 8) + record[4] = byte(n) + hc.incSeq() + + return record, nil +} + +// RecordHeaderError is returned when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte + // Conn provides the underlying net.Conn in the case that a client + // sent an initial handshake that didn't look like TLS. + // It is nil if there's already been a handshake or a TLS alert has + // been written to the connection. + Conn net.Conn +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { + err.Msg = msg + err.Conn = conn + copy(err.RecordHeader[:], c.rawInput.Bytes()) + return err +} + +func (c *Conn) readRecord() error { + return c.readRecordOrCCS(false) +} + +func (c *Conn) readChangeCipherSpec() error { + return c.readRecordOrCCS(true) +} + +// readRecordOrCCS reads one or more TLS records from the connection and +// updates the record layer state. Some invariants: +// * c.in must be locked +// * c.input must be empty +// During the handshake one and only one of the following will happen: +// - c.hand grows +// - c.in.changeCipherSpec is called +// - an error is returned +// After the handshake one and only one of the following will happen: +// - c.hand grows +// - c.input is set +// - an error is returned +func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { + if c.in.err != nil { + return c.in.err + } + handshakeComplete := c.handshakeComplete() + + // This function modifies c.rawInput, which owns the c.input memory. + if c.input.Len() != 0 { + return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) + } + c.input.Reset(nil) + + // Read header, payload. + if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify + // is an error, but popular web sites seem to do this, so we accept it + // if and only if at the record boundary. + if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 { + err = io.EOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + hdr := c.rawInput.Bytes()[:recordHeaderLen] + typ := recordType(hdr[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + if !handshakeComplete && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + } + + vers := uint16(hdr[1])<<8 | uint16(hdr[2]) + n := int(hdr[3])<<8 | int(hdr[4]) + if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { + c.sendAlert(alertProtocolVersion) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if !c.haveVers { + // First message, be extra suspicious: this might not be a TLS + // client. Bail out before reading a full 'body', if possible. + // The current max version is 3.3 so if the version is >= 16.0, + // it's probably not real. + if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { + return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) + } + } + if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + + // Process message. + record := c.rawInput.Next(recordHeaderLen + n) + data, typ, err := c.in.decrypt(record) + if err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + if len(data) > maxPlaintext { + return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) + } + + // Application Data messages are always protected. + if c.in.cipher == nil && typ == recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + c.retryCount = 0 + } + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + switch typ { + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if alert(data[1]) == alertCloseNotify { + return c.in.setErrorLocked(io.EOF) + } + if c.vers == VersionTLS13 { + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + } + switch data[0] { + case alertLevelWarning: + // Drop the record on the floor and retry. + return c.retryReadRecord(expectChangeCipherSpec) + case alertLevelError: + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if len(data) != 1 || data[0] != 1 { + return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) + } + // Handshake messages are not allowed to fragment across the CCS. + if c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // In TLS 1.3, change_cipher_spec records are ignored until the + // Finished. See RFC 8446, Appendix D.4. Note that according to Section + // 5, a server can send a ChangeCipherSpec before its ServerHello, when + // c.vers is still unset. That's not useful though and suspicious if the + // server then selects a lower protocol version, so don't allow that. + if c.vers == VersionTLS13 { + return c.retryReadRecord(expectChangeCipherSpec) + } + if !expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if err := c.in.changeCipherSpec(); err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if !handshakeComplete || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. Ignore a limited number of empty records. + if len(data) == 0 { + return c.retryReadRecord(expectChangeCipherSpec) + } + // Note that data is owned by c.rawInput, following the Next call above, + // to avoid copying the plaintext. This is safe because c.rawInput is + // not read from or written to until c.input is drained. + c.input.Reset(data) + + case recordTypeHandshake: + if len(data) == 0 || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.hand.Write(data) + } + + return nil +} + +// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + } + return c.readRecordOrCCS(expectChangeCipherSpec) +} + +// atLeastReader reads from R, stopping with EOF once at least N bytes have been +// read. It is different from an io.LimitedReader in that it doesn't cut short +// the last Read call, and in that it considers an early EOF an error. +type atLeastReader struct { + R io.Reader + N int64 +} + +func (r *atLeastReader) Read(p []byte) (int, error) { + if r.N <= 0 { + return 0, io.EOF + } + n, err := r.R.Read(p) + r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 + if r.N > 0 && err == io.EOF { + return n, io.ErrUnexpectedEOF + } + if r.N <= 0 && err == nil { + return n, io.EOF + } + return n, err +} + +// readFromUntil reads from r into c.rawInput until c.rawInput contains +// at least n bytes or else returns an error. +func (c *Conn) readFromUntil(r io.Reader, n int) error { + if c.rawInput.Len() >= n { + return nil + } + needs := n - c.rawInput.Len() + // There might be extra input waiting on the wire. Make a best effort + // attempt to fetch it so that it can be used in (*Conn).Read to + // "predict" closeNotify alerts. + c.rawInput.Grow(needs + bytes.MinRead) + _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) + return err +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlert(err alert) error { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err)) + return &net.OpError{Op: "local error", Err: err} + } + + c.out.Lock() + defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= c.out.mac.Size() + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= c.out.mac.Size() + default: + panic("unknown cipher type") + } + } + if c.vers == VersionTLS13 { + payloadBytes-- // encrypted ContentType + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +func (c *Conn) write(data []byte) (int, error) { + if c.buffering { + c.sendBuf = append(c.sendBuf, data...) + return len(data), nil + } + + n, err := c.conn.Write(data) + c.bytesSent += int64(n) + return n, err +} + +func (c *Conn) flush() (int, error) { + if len(c.sendBuf) == 0 { + return 0, nil + } + + n, err := c.conn.Write(c.sendBuf) + c.bytesSent += int64(n) + c.sendBuf = nil + c.buffering = false + return n, err +} + +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() interface{} { + return new([]byte) + }, +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + + var n int + for len(data) > 0 { + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } else if vers == VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = VersionTLS12 + } + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) + + var err error + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) + if err != nil { + return n, err + } + if _, err := c.write(outBuf); err != nil { + return n, err + } + n += m + data = data[m:] + } + + if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) + } + } + + return n, nil +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + if typ == recordTypeChangeCipherSpec { + return len(data), nil + } + return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) + } + + c.out.Lock() + defer c.out.Unlock() + + return c.writeRecordLocked(typ, data) +} + +// readHandshake reads the next handshake message from +// the record layer. +func (c *Conn) readHandshake() (interface{}, error) { + var data []byte + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + var err error + data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage() + if err != nil { + return nil, err + } + } else { + for c.hand.Len() < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + + data = c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) + } + for c.hand.Len() < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + data = c.hand.Next(4 + n) + } + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeNewSessionTicket: + if c.vers == VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + m = new(newSessionTicketMsg) + } + case typeCertificate: + if c.vers == VersionTLS13 { + m = new(certificateMsgTLS13) + } else { + m = new(certificateMsg) + } + case typeCertificateRequest: + if c.vers == VersionTLS13 { + m = new(certificateRequestMsgTLS13) + } else { + m = &certificateRequestMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + case typeFinished: + m = new(finishedMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return m, nil +} + +var ( + errShutdown = errors.New("tls: protocol is shutdown") +) + +// Write writes data to the connection. +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Write(b []byte) (int, error) { + // interlock with Close below + for { + x := atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return 0, net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + break + } + } + defer atomic.AddInt32(&c.activeCall, -2) + + if err := c.Handshake(); err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if err := c.out.err; err != nil { + return 0, err + } + + if !c.handshakeComplete() { + return 0, alertInternalError + } + + if c.closeNotifySent { + return 0, errShutdown + } + + // TLS 1.0 is susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // https://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // https://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && c.vers == VersionTLS10 { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) +} + +// handleRenegotiation processes a HelloRequest handshake message. +func (c *Conn) handleRenegotiation() error { + if c.vers == VersionTLS13 { + return errors.New("tls: internal error: unexpected renegotiation") + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + helloReq, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(helloReq, msg) + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + atomic.StoreUint32(&c.handshakeStatus, 0) + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +func (c *Conn) HandlePostHandshakeMessage() error { + return c.handlePostHandshakeMessage() +} + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *Conn) handlePostHandshakeMessage() error { + if c.vers != VersionTLS13 { + return c.handleRenegotiation() + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + } + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + return c.handleNewSessionTicket(msg) + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + default: + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) + } +} + +func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil { + return c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + + newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) + c.in.setTrafficSecret(cipherSuite, newSecret) + + if keyUpdate.updateRequested { + c.out.Lock() + defer c.out.Unlock() + + msg := &keyUpdateMsg{} + _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + if err != nil { + // Surface the error at the next write. + c.out.setErrorLocked(err) + return nil + } + + newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) + c.out.setTrafficSecret(cipherSuite, newSecret) + } + + return nil +} + +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Read(b []byte) (int, error) { + if err := c.Handshake(); err != nil { + return 0, err + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + c.in.Lock() + defer c.in.Unlock() + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + return 0, err + } + for c.hand.Len() > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + return 0, err + } + } + } + + n, _ := c.input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && + recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +// Close closes the connection. +func (c *Conn) Close() error { + // Interlock with Conn.Write above. + var x int32 + for { + x = atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + + var alertErr error + if c.handshakeComplete() { + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } + } + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + if !c.handshakeComplete() { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + c.out.Lock() + defer c.out.Unlock() + + if !c.closeNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// HandshakeContext or the Dialer's DialContext method instead. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + // Fast sync/atomic-based exit if there is no handshake in flight and the + // last one succeeded without an error. Avoids the expensive context setup + // and mutex for most Read and Write calls. + if c.handshakeComplete() { + return nil + } + + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete() { + return nil + } + + c.in.Lock() + defer c.in.Unlock() + + c.handshakeErr = c.handshakeFn(handshakeCtx) + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the handshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + if c.handshakeErr == nil && !c.handshakeComplete() { + c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") + } + if c.handshakeErr != nil && c.handshakeComplete() { + panic("tls: internal error: handshake returned an error but is marked successful") + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} + +// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection. +func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return ConnectionStateWith0RTT{ + ConnectionState: c.connectionStateLocked(), + Used0RTT: c.used0RTT, + } +} + +func (c *Conn) connectionStateLocked() ConnectionState { + var state connectionState + state.HandshakeComplete = c.handshakeComplete() + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = true + state.ServerName = c.serverName + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } + return toConnectionState(state) +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.handshakeComplete() { + return errors.New("tls: handshake has not yet been performed") + } + if len(c.verifiedChains) == 0 { + return errors.New("tls: handshake did not verify certificate chain") + } + return c.peerCertificates[0].VerifyHostname(host) +} + +func (c *Conn) handshakeComplete() bool { + return atomic.LoadUint32(&c.handshakeStatus) == 1 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/cpu.go b/vendor/github.com/marten-seemann/qtls-go1-17/cpu.go new file mode 100644 index 00000000..12194508 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/cpu.go @@ -0,0 +1,22 @@ +//go:build !js +// +build !js + +package qtls + +import ( + "runtime" + + "golang.org/x/sys/cpu" +) + +var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && + (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || + runtime.GOARCH == "arm64" && hasGCMAsmARM64 || + runtime.GOARCH == "s390x" && hasGCMAsmS390X +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/cpu_other.go b/vendor/github.com/marten-seemann/qtls-go1-17/cpu_other.go new file mode 100644 index 00000000..33f7d219 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/cpu_other.go @@ -0,0 +1,12 @@ +//go:build js +// +build js + +package qtls + +var ( + hasGCMAsmAMD64 = false + hasGCMAsmARM64 = false + hasGCMAsmS390X = false + + hasAESGCMHardwareSupport = false +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go new file mode 100644 index 00000000..ac5d0b4c --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go @@ -0,0 +1,1111 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "strings" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +const clientSessionStateVersion = 1 + +type clientHandshakeState struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *clientSessionState +} + +func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { + config := c.config + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + nextProtosLength := 0 + for _, proto := range config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return nil, nil, errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + if nextProtosLength > 0xffff { + return nil, nil, errors.New("tls: NextProtos values too large") + } + + var supportedVersions []uint16 + var clientHelloVersion uint16 + if c.extraConfig.usesAlternativeRecordLayer() { + if config.maxSupportedVersion() < VersionTLS13 { + return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + // Only offer TLS 1.3 when QUIC is used. + supportedVersions = []uint16{VersionTLS13} + clientHelloVersion = VersionTLS13 + } else { + supportedVersions = config.supportedVersions() + if len(supportedVersions) == 0 { + return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") + } + clientHelloVersion = config.maxSupportedVersion() + } + + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if clientHelloVersion > VersionTLS12 { + clientHelloVersion = VersionTLS12 + } + + hello := &clientHelloMsg{ + vers: clientHelloVersion, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + ocspStapling: true, + scts: true, + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + secureRenegotiationSupported: true, + alpnProtocols: config.NextProtos, + supportedVersions: supportedVersions, + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + configCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(configCipherSuites)) + + for _, suiteId := range preferenceOrder { + suite := mutualCipherSuite(configCipherSuites, suiteId) + if suite == nil { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + continue + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + } + + _, err := io.ReadFull(config.rand(), hello.random) + if err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + hello.sessionId = make([]byte, 32) + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + } + + if hello.vers >= VersionTLS12 { + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + var params ecdheParameters + if hello.supportedVersions[0] == VersionTLS13 { + var suites []uint16 + for _, suiteID := range configCipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + suites = append(suites, suiteID) + } + } + } + if len(suites) > 0 { + hello.cipherSuites = suites + } else { + if hasAESGCMHardwareSupport { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) + } else { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) + } + } + + curveID := config.curvePreferences()[0] + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err = generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil { + hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello) + } + + return hello, params, nil +} + +func (c *Conn) clientHandshake(ctx context.Context) (err error) { + if c.config == nil { + c.config = fromConfig(defaultConfig()) + } + c.setAlternativeRecordLayer() + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, ecdheParams, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName + + cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + if cacheKey != "" && session != nil { + var deletedTicket bool + if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { + // don't reuse a session ticket that enabled 0-RTT + c.config.ClientSessionCache.Put(cacheKey, nil) + deletedTicket = true + + if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { + h := suite.hash.New() + h.Write(hello.marshal()) + clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) + c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + } + if !deletedTicket { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + }() + } + } + + if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + + if c.vers == VersionTLS13 { + hs := &clientHandshakeStateTLS13{ + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + ecdheParams: ecdheParams, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + } + + // In TLS 1.3, session tickets are delivered after the handshake. + return hs.handshake() + } + + hs := &clientHandshakeState{ + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + session: session, + } + + if err := hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one. + if cacheKey != "" && hs.session != nil && session != hs.session { + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) + } + + return nil +} + +// extract the app data saved in the session.nonce, +// and set the session.nonce to the actual nonce value +func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) { + s := cryptobyte.String(session.nonce) + var version uint16 + if !s.ReadUint16(&version) { + return 0, nil, false + } + if version != clientSessionStateVersion { + return 0, nil, false + } + var maxEarlyData uint32 + if !s.ReadUint32(&maxEarlyData) { + return 0, nil, false + } + var appData []byte + if !readUint16LengthPrefixed(&s, &appData) { + return 0, nil, false + } + var nonce []byte + if !readUint16LengthPrefixed(&s, &nonce) { + return 0, nil, false + } + session.nonce = nonce + return maxEarlyData, appData, true +} + +func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, + session *clientSessionState, earlySecret, binderKey []byte) { + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return "", nil, nil, nil + } + + hello.ticketSupported = true + + if hello.supportedVersions[0] == VersionTLS13 { + // Require DHE on resumption as it guarantees forward secrecy against + // compromise of the session ticket key. See RFC 8446, Section 4.2.9. + hello.pskModes = []uint8{pskModeDHE} + } + + // Session resumption is not allowed if renegotiating because + // renegotiation is primarily used to allow a client to send a client + // certificate, which would be skipped if session resumption occurred. + if c.handshakes != 0 { + return "", nil, nil, nil + } + + // Try to resume a previously negotiated TLS session, if available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + sess, ok := c.config.ClientSessionCache.Get(cacheKey) + if !ok || sess == nil { + return cacheKey, nil, nil, nil + } + session = fromClientSessionState(sess) + + var appData []byte + var maxEarlyData uint32 + if session.vers == VersionTLS13 { + var ok bool + maxEarlyData, appData, ok = c.decodeSessionState(session) + if !ok { // delete it, if parsing failed + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + } + + // Check that version used for the previous session is still valid. + versOk := false + for _, v := range hello.supportedVersions { + if v == session.vers { + versOk = true + break + } + } + if !versOk { + return cacheKey, nil, nil, nil + } + + // Check that the cached server certificate is not expired, and that it's + // valid for the ServerName. This should be ensured by the cache key, but + // protect the application from a faulty ClientSessionCache implementation. + if !c.config.InsecureSkipVerify { + if len(session.verifiedChains) == 0 { + // The original connection had InsecureSkipVerify, while this doesn't. + return cacheKey, nil, nil, nil + } + serverCert := session.serverCertificates[0] + if c.config.time().After(serverCert.NotAfter) { + // Expired certificate, delete the entry. + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { + return cacheKey, nil, nil, nil + } + } + + if session.vers != VersionTLS13 { + // In TLS 1.2 the cipher suite must match the resumed session. Ensure we + // are still offering it. + if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { + return cacheKey, nil, nil, nil + } + + hello.sessionTicket = session.sessionTicket + return + } + + // Check that the session ticket is not expired. + if c.config.time().After(session.useBy) { + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + + // In TLS 1.3 the KDF hash must match the resumed session. Ensure we + // offer at least one cipher suite with that hash. + cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) + if cipherSuite == nil { + return cacheKey, nil, nil, nil + } + cipherSuiteOk := false + for _, offeredID := range hello.cipherSuites { + offeredSuite := cipherSuiteTLS13ByID(offeredID) + if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return cacheKey, nil, nil, nil + } + + // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. + ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) + identity := pskIdentity{ + label: session.sessionTicket, + obfuscatedTicketAge: ticketAge + session.ageAdd, + } + hello.pskIdentities = []pskIdentity{identity} + hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} + + // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. + psk := cipherSuite.expandLabel(session.masterSecret, "resumption", + session.nonce, cipherSuite.hash.Size()) + earlySecret = cipherSuite.extract(psk, nil) + binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + if c.extraConfig != nil { + hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 + } + transcript := cipherSuite.hash.New() + transcript.Write(hello.marshalWithoutBinders()) + pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} + hello.updateBinders(pskBinders) + + if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { + c.extraConfig.SetAppDataFromSessionState(appData) + } + return +} + +func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { + peerVersion := serverHello.vers + if serverHello.supportedVersion != 0 { + peerVersion = serverHello.supportedVersion + } + + vers, ok := c.config.mutualVersion([]uint16{peerVersion}) + if !ok { + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) + } + + c.vers = vers + c.haveVers = true + c.in.version = vers + c.out.version = vers + + return nil +} + +// Does the handshake, either a full one or resumes old session. Requires hs.c, +// hs.hello, hs.serverHello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + c := hs.c + + isResume, err := hs.processServerHello() + if err != nil { + return err + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + + // No signatures of the handshake are needed in a resumption. + // Otherwise, in a full handshake, if we don't have any certificates + // configured then we will never send a CertificateVerify message and + // thus no signatures are needed in that case either. + if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { + hs.finishedHash.discardHandshakeBuffer() + } + + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + + c.buffering = true + c.didResume = isResume + if isResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = true + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + + hs.c.cipherSuite = hs.suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + msg, err = c.readHandshake() + if err != nil { + return err + } + + cs, ok := msg.(*certificateStatusMsg) + if ok { + // RFC4366 on Certificate Status Request: + // The server MAY return a "certificate_status" message. + + if !hs.serverHello.ocspStapling { + // If a server returns a "CertificateStatus" message, then the + // server MUST have included an extension of type "status_request" + // with empty "extension_data" in the extended server hello. + + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received unexpected CertificateStatus message") + } + hs.finishedHash.Write(cs.marshal()) + + c.ocspResponse = cs.response + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + + keyAgreement := hs.suite.ka(c.vers) + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.finishedHash.Write(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + hs.finishedHash.Write(certReq.marshal()) + + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) + if chainToSend, err = c.getClientCertificate(cri); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.finishedHash.Write(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg = new(certificateMsg) + certMsg.certificates = chainToSend.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + hs.finishedHash.Write(ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } + } + + if chainToSend != nil && len(chainToSend.Certificate) > 0 { + certVerify := &certificateVerifyMsg{} + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + certVerify.hasSignatureAlgorithm = true + certVerify.signatureAlgorithm = signatureAlgorithm + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.finishedHash.Write(certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } + } + + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to write to key log: " + err.Error()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher interface{} + var clientHash, serverHash hash.Hash + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if err := hs.pickCipherSuite(); err != nil { + return false, err + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { + c.secureRenegotiation = true + if len(hs.serverHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + } + + if c.handshakes > 0 && c.secureRenegotiation { + var expectedSecureRenegotiation [24]byte + copy(expectedSecureRenegotiation[:], c.clientFinished[:]) + copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) + if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: incorrect renegotiation extension contents") + } + } + + if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { + c.sendAlert(alertUnsupportedExtension) + return false, err + } + c.clientProtocol = hs.serverHello.alpnProtocol + + c.scts = hs.serverHello.scts + + if !hs.serverResumedSession() { + return false, nil + } + + if hs.session.vers != c.vers { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different version") + } + + if hs.session.cipherSuite != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different cipher suite") + } + + // Restore masterSecret, peerCerts, and ocspResponse from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + // Let the ServerHello SCTs override the session SCTs from the original + // connection, if any are provided + if len(c.scts) == 0 && len(hs.session.scts) != 0 { + c.scts = hs.session.scts + } + + return true, nil +} + +// checkALPN ensure that the server's choice of ALPN protocol is compatible with +// the protocols that we advertised in the Client Hello. +func checkALPN(clientProtos []string, serverProto string) error { + if serverProto == "" { + return nil + } + if len(clientProtos) == 0 { + return errors.New("tls: server advertised unrequested ALPN extension") + } + for _, proto := range clientProtos { + if proto == serverProto { + return nil + } + } + return errors.New("tls: server selected unadvertised ALPN protocol") +} + +func (hs *clientHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &clientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + return nil +} + +func (hs *clientHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + copy(out, finished.verifyData) + return nil +} + +// verifyServerCertificate parses and verifies the provided chain, setting +// c.verifiedChains and c.peerCertificates or sending the appropriate alert. +func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + var err error + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + switch certs[0].PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) + } + + c.peerCertificates = certs + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS +// <= 1.2 CertificateRequest, making an effort to fill in missing information. +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { + cri := &certificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + Version: vers, + ctx: ctx, + } + + var rsaAvail, ecAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecAvail = true + } + } + + if !certReq.hasSignatureAlgorithm { + // Prior to TLS 1.2, signature schemes did not exist. In this case we + // make up a list based on the acceptable certificate types, to help + // GetClientCertificate and SupportsCertificate select the right certificate. + // The hash part of the SignatureScheme is a lie here, because + // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. + switch { + case rsaAvail && ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case rsaAvail: + cri.SignatureSchemes = []SignatureScheme{ + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + } + } + return toCertificateRequestInfo(cri) + } + + // Filter the signature schemes based on the certificate types. + // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). + cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) + for _, sigScheme := range certReq.supportedSignatureAlgorithms { + sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) + if err != nil { + continue + } + switch sigType { + case signatureECDSA, signatureEd25519: + if ecAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + case signatureRSAPSS, signaturePKCS1v15: + if rsaAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + } + } + + return toCertificateRequestInfo(cri) +} + +func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { + if c.config.GetClientCertificate != nil { + return c.config.GetClientCertificate(cri) + } + + for _, chain := range c.config.Certificates { + if err := cri.SupportsCertificate(&chain); err != nil { + continue + } + return &chain, nil + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + +const clientSessionCacheKeyPrefix = "qtls-" + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *config) string { + if len(config.ServerName) > 0 { + return clientSessionCacheKeyPrefix + config.ServerName + } + return clientSessionCacheKeyPrefix + serverAddr.String() +} + +// hostnameInSNI converts name into an appropriate hostname for SNI. +// Literal IP addresses and absolute FQDNs are not permitted as SNI values. +// See RFC 6066, Section 3. +func hostnameInSNI(name string) string { + host := name + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + if i := strings.LastIndex(host, "%"); i > 0 { + host = host[:i] + } + if net.ParseIP(host) != nil { + return "" + } + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return name +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go new file mode 100644 index 00000000..0de59fc1 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go @@ -0,0 +1,732 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/hmac" + "crypto/rsa" + "encoding/binary" + "errors" + "hash" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +type clientHandshakeStateTLS13 struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheParams ecdheParameters + + session *clientSessionState + earlySecret []byte + binderKey []byte + + certReq *certificateRequestMsgTLS13 + usingPSK bool + sentDummyCCS bool + suite *cipherSuiteTLS13 + transcript hash.Hash + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 +} + +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. +func (hs *clientHandshakeStateTLS13) handshake() error { + c := hs.c + + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, + // sections 4.1.2 and 4.1.3. + if c.handshakes > 0 { + c.sendAlert(alertProtocolVersion) + return errors.New("tls: server selected TLS 1.3 in a renegotiation") + } + + // Consistency check on the presence of a keyShare and its parameters. + if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + return c.sendAlert(alertInternalError) + } + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + hs.transcript = hs.suite.hash.New() + hs.transcript.Write(hs.hello.marshal()) + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.processHelloRetryRequest(); err != nil { + return err + } + } + + hs.transcript.Write(hs.serverHello.marshal()) + + c.buffering = true + if err := hs.processServerHello(); err != nil { + return err + } + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.establishHandshakeKeys(); err != nil { + return err + } + if err := hs.readServerParameters(); err != nil { + return err + } + if err := hs.readServerCertificate(); err != nil { + return err + } + if err := hs.readServerFinished(); err != nil { + return err + } + if err := hs.sendClientCertificate(); err != nil { + return err + } + if err := hs.sendClientFinished(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// checkServerHelloOrHRR does validity checks that apply to both ServerHello and +// HelloRetryRequest messages. It sets hs.suite. +func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { + c := hs.c + + if hs.serverHello.supportedVersion == 0 { + c.sendAlert(alertMissingExtension) + return errors.New("tls: server selected TLS 1.3 using the legacy version field") + } + + if hs.serverHello.supportedVersion != VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid version after a HelloRetryRequest") + } + + if hs.serverHello.vers != VersionTLS12 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an incorrect legacy version") + } + + if hs.serverHello.ocspStapling || + hs.serverHello.ticketSupported || + hs.serverHello.secureRenegotiationSupported || + len(hs.serverHello.secureRenegotiation) != 0 || + len(hs.serverHello.alpnProtocol) != 0 || + len(hs.serverHello.scts) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") + } + + if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not echo the legacy session ID") + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported compression format") + } + + selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) + if hs.suite != nil && selectedSuite != hs.suite { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server changed cipher suite after a HelloRetryRequest") + } + if selectedSuite == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server chose an unconfigured cipher suite") + } + hs.suite = selectedSuite + c.cipherSuite = hs.suite.id + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and +// resends hs.hello, and reads the new ServerHello into hs.serverHello. +func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. (The idea is that the server might offload transcript + // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + hs.transcript.Write(hs.serverHello.marshal()) + + // The only HelloRetryRequest extensions we support are key_share and + // cookie, and clients must abort the handshake if the HRR would not result + // in any change in the ClientHello. + if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest message") + } + + if hs.serverHello.cookie != nil { + hs.hello.cookie = hs.serverHello.cookie + } + + if hs.serverHello.serverShare.group != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received malformed key_share extension") + } + + // If the server sent a key_share extension selecting a group, ensure it's + // a group we advertised but did not send a key share for, and send a key + // share for it this time. + if curveID := hs.serverHello.selectedGroup; curveID != 0 { + curveOK := false + for _, id := range hs.hello.supportedCurves { + if id == curveID { + curveOK = true + break + } + } + if !curveOK { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + if hs.ecdheParams.CurveID() == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.ecdheParams = params + hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + hs.hello.raw = nil + if len(hs.hello.pskIdentities) > 0 { + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash == hs.suite.hash { + // Update binders and obfuscated_ticket_age. + ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) + hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd + + transcript := hs.suite.hash.New() + transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + transcript.Write(chHash) + transcript.Write(hs.serverHello.marshal()) + transcript.Write(hs.hello.marshalWithoutBinders()) + pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} + hs.hello.updateBinders(pskBinders) + } else { + // Server selected a cipher suite incompatible with the PSK. + hs.hello.pskIdentities = nil + hs.hello.pskBinders = nil + } + } + + if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + hs.hello.earlyData = false // disable 0-RTT + + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + hs.serverHello = serverHello + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) processServerHello() error { + c := hs.c + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: server sent two HelloRetryRequest messages") + } + + if len(hs.serverHello.cookie) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a cookie in a normal ServerHello") + } + + if hs.serverHello.selectedGroup != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: malformed key_share extension") + } + + if hs.serverHello.serverShare.group == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not send a key share") + } + if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + + if !hs.serverHello.selectedIdentityPresent { + return nil + } + + if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK") + } + + if len(hs.hello.pskIdentities) != 1 || hs.session == nil { + return c.sendAlert(alertInternalError) + } + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash != hs.suite.hash { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK and cipher suite pair") + } + + hs.usingPSK = true + c.didResume = true + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + c.scts = hs.session.scts + return nil +} + +func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { + c := hs.c + + sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) + if sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + + earlySecret := hs.earlySecret + if !hs.usingPSK { + earlySecret = hs.suite.extract(nil, nil) + } + handshakeSecret := hs.suite.extract(sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.out.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(handshakeSecret, "derived", nil)) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerParameters() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + // Notify the caller if 0-RTT was rejected. + if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + c.used0RTT = encryptedExtensions.earlyData + if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { + hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) + } + hs.transcript.Write(encryptedExtensions.marshal()) + + if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { + c.sendAlert(alertUnsupportedExtension) + return err + } + c.clientProtocol = encryptedExtensions.alpnProtocol + + if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection { + if len(encryptedExtensions.alpnProtocol) == 0 { + // the server didn't select an ALPN + c.sendAlert(alertNoApplicationProtocol) + return errors.New("ALPN negotiation failed. Server didn't offer any protocols") + } + } + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerCertificate() error { + c := hs.c + + // Either a PSK or a certificate is always used, but not both. + // See RFC 8446, Section 4.1.1. + if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certReq, ok := msg.(*certificateRequestMsgTLS13) + if ok { + hs.transcript.Write(certReq.marshal()) + + hs.certReq = certReq + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + if len(certMsg.certificate.Certificate) == 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received empty certificates message") + } + hs.transcript.Write(certMsg.marshal()) + + c.scts = certMsg.certificate.SignedCertificateTimestamps + c.ocspResponse = certMsg.certificate.OCSPStaple + + if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + if !hmac.Equal(expectedMAC, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid server finished hash") + } + + hs.transcript.Write(finished.marshal()) + + // Derive secrets that take context through the server Finished. + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { + c := hs.c + + if hs.certReq == nil { + return nil + } + + cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{ + AcceptableCAs: hs.certReq.certificateAuthorities, + SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, + Version: c.vers, + ctx: hs.ctx, + })) + if err != nil { + return err + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *cert + certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + // If we sent an empty certificate message, skip the CertificateVerify. + if len(cert.Certificate) == 0 { + return nil + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + + certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) + if err != nil { + // getClientCertificate returned a certificate incompatible with the + // CertificateRequestInfo supported signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.out.setTrafficSecret(hs.suite, hs.trafficSecret) + + if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + } + + return nil +} + +func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received new session ticket from a client") + } + + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return nil + } + + // See RFC 8446, Section 4.6.1. + if msg.lifetime == 0 { + return nil + } + lifetime := time.Duration(msg.lifetime) * time.Second + if lifetime > maxSessionTicketLifetime { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: received a session ticket with invalid lifetime") + } + + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil || c.resumptionSecret == nil { + return c.sendAlert(alertInternalError) + } + + // We need to save the max_early_data_size that the server sent us, in order + // to decide if we're going to try 0-RTT with this ticket. + // However, at the same time, the qtls.ClientSessionTicket needs to be equal to + // the tls.ClientSessionTicket, so we can't just add a new field to the struct. + // We therefore abuse the nonce field (which is a byte slice) + nonceWithEarlyData := make([]byte, len(msg.nonce)+4) + binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData) + copy(nonceWithEarlyData[4:], msg.nonce) + + var appData []byte + if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil { + appData = c.extraConfig.GetAppDataForSessionState() + } + var b cryptobyte.Builder + b.AddUint16(clientSessionStateVersion) // revision + b.AddUint32(msg.maxEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(appData) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(msg.nonce) + }) + + // Save the resumption_master_secret and nonce instead of deriving the PSK + // to do the least amount of work on NewSessionTicket messages before we + // know if the ticket will be used. Forward secrecy of resumed connections + // is guaranteed by the requirement for pskModeDHE. + session := &clientSessionState{ + sessionTicket: msg.label, + vers: c.vers, + cipherSuite: c.cipherSuite, + masterSecret: c.resumptionSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + nonce: b.BytesOrPanic(), + useBy: c.config.time().Add(lifetime), + ageAdd: msg.ageAdd, + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session)) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go new file mode 100644 index 00000000..1ab75762 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go @@ -0,0 +1,1832 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "fmt" + "strings" + + "golang.org/x/crypto/cryptobyte" +) + +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. +func addUint64(b *cryptobyte.Builder, v uint64) { + b.AddUint32(uint32(v >> 32)) + b.AddUint32(uint32(v)) +} + +// readUint64 decodes a big-endian, 64-bit value into out and advances over it. +// It reports whether the read was successful. +func readUint64(s *cryptobyte.String, out *uint64) bool { + var hi, lo uint32 + if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { + return false + } + *out = uint64(hi)<<32 | uint64(lo) + return true +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocols []string + scts bool + supportedVersions []uint16 + cookie []byte + keyShares []keyShare + earlyData bool + pskModes []uint8 + pskIdentities []pskIdentity + pskBinders [][]byte + additionalExtensions []Extension +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeClientHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + b.AddUint16(suite) + } + }) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.compressionMethods) + }) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + b.AddUint16(extensionServerName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // name_type = host_name + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(1) // status_type = ocsp + b.AddUint16(0) // empty responder_id_list + b.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + b.AddUint16(extensionSupportedCurves) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + b.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + b.AddUint16(extensionSessionTicket) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + b.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ks := range m.keyShares { + b.AddUint16(uint16(ks.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + b.AddUint16(extensionPSKModes) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.pskModes) + }) + }) + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(psk.label) + }) + b.AddUint32(psk.obfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +// marshalWithoutBinders returns the ClientHello through the +// PreSharedKeyExtension.identities field, according to RFC 8446, Section +// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. +func (m *clientHelloMsg) marshalWithoutBinders() []byte { + bindersLen := 2 // uint16 length prefix + for _, binder := range m.pskBinders { + bindersLen += 1 // uint8 length prefix + bindersLen += len(binder) + } + + fullMessage := m.marshal() + return fullMessage[:len(fullMessage)-bindersLen] +} + +// updateBinders updates the m.pskBinders field, if necessary updating the +// cached marshaled representation. The supplied binders must have the same +// length as the current m.pskBinders. +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { + if len(pskBinders) != len(m.pskBinders) { + panic("tls: internal error: pskBinders length mismatch") + } + for i := range m.pskBinders { + if len(pskBinders[i]) != len(m.pskBinders[i]) { + panic("tls: internal error: pskBinders length mismatch") + } + } + m.pskBinders = pskBinders + if m.raw != nil { + lenWithoutBinders := len(m.marshalWithoutBinders()) + // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. + b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + if len(b.BytesOrPanic()) != len(m.raw) { + panic("tls: internal error: failed to update binders") + } + } +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + *m = clientHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) { + return false + } + + var cipherSuites cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuites) { + return false + } + m.cipherSuites = []uint16{} + m.secureRenegotiationSupported = false + for !cipherSuites.Empty() { + var suite uint16 + if !cipherSuites.ReadUint16(&suite) { + return false + } + if suite == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + m.cipherSuites = append(m.cipherSuites, suite) + } + + if !readUint8LengthPrefixed(&s, &m.compressionMethods) { + return false + } + + if s.Empty() { + // ClientHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var ext uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&ext) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch ext { + case extensionServerName: + // RFC 6066, Section 3 + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return false + } + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return false + } + if nameType != 0 { + continue + } + if len(m.serverName) != 0 { + // Multiple names of the same name_type are prohibited. + return false + } + m.serverName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(m.serverName, ".") { + return false + } + } + case extensionStatusRequest: + // RFC 4366, Section 3.6 + var statusType uint8 + var ignored cryptobyte.String + if !extData.ReadUint8(&statusType) || + !extData.ReadUint16LengthPrefixed(&ignored) || + !extData.ReadUint16LengthPrefixed(&ignored) { + return false + } + m.ocspStapling = statusType == statusTypeOCSP + case extensionSupportedCurves: + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + var curves cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { + return false + } + for !curves.Empty() { + var curve uint16 + if !curves.ReadUint16(&curve) { + return false + } + m.supportedCurves = append(m.supportedCurves, CurveID(curve)) + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + case extensionSessionTicket: + // RFC 5077, Section 3.2 + m.ticketSupported = true + extData.ReadBytes(&m.sessionTicket, len(extData)) + case extensionSignatureAlgorithms: + // RFC 5246, Section 7.4.1.4.1 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + // RFC 8446, Section 4.2.3 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionRenegotiationInfo: + // RFC 5746, Section 3.2 + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + // RFC 7301, Section 3.1 + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(proto)) + } + case extensionSCT: + // RFC 6962, Section 3.3.1 + m.scts = true + case extensionSupportedVersions: + // RFC 8446, Section 4.2.1 + var versList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { + return false + } + for !versList.Empty() { + var vers uint16 + if !versList.ReadUint16(&vers) { + return false + } + m.supportedVersions = append(m.supportedVersions, vers) + } + case extensionCookie: + // RFC 8446, Section 4.2.2 + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // RFC 8446, Section 4.2.8 + var clientShares cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&clientShares) { + return false + } + for !clientShares.Empty() { + var ks keyShare + if !clientShares.ReadUint16((*uint16)(&ks.group)) || + !readUint16LengthPrefixed(&clientShares, &ks.data) || + len(ks.data) == 0 { + return false + } + m.keyShares = append(m.keyShares, ks) + } + case extensionEarlyData: + // RFC 8446, Section 4.2.10 + m.earlyData = true + case extensionPSKModes: + // RFC 8446, Section 4.2.9 + if !readUint8LengthPrefixed(&extData, &m.pskModes) { + return false + } + case extensionPreSharedKey: + // RFC 8446, Section 4.2.11 + if !extensions.Empty() { + return false // pre_shared_key must be the last extension + } + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return false + } + for !identities.Empty() { + var psk pskIdentity + if !readUint16LengthPrefixed(&identities, &psk.label) || + !identities.ReadUint32(&psk.obfuscatedTicketAge) || + len(psk.label) == 0 { + return false + } + m.pskIdentities = append(m.pskIdentities, psk) + } + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return false + } + for !binders.Empty() { + var binder []byte + if !readUint8LengthPrefixed(&binders, &binder) || + len(binder) == 0 { + return false + } + m.pskBinders = append(m.pskBinders, binder) + } + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + ocspStapling bool + ticketSupported bool + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocol string + scts [][]byte + supportedVersion uint16 + serverShare keyShare + selectedIdentityPresent bool + selectedIdentity uint16 + supportedPoints []uint8 + + // HelloRetryRequest extensions + cookie []byte + selectedGroup CurveID +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeServerHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16(m.cipherSuite) + b.AddUint8(m.compressionMethod) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + b.AddUint16(extensionSessionTicket) + b.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.scts { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.serverShare.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + *m = serverHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) || + !s.ReadUint16(&m.cipherSuite) || + !s.ReadUint8(&m.compressionMethod) { + return false + } + + if s.Empty() { + // ServerHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSessionTicket: + m.ticketSupported = true + case extensionRenegotiationInfo: + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + m.scts = append(m.scts, sct) + } + case extensionSupportedVersions: + if !extData.ReadUint16(&m.supportedVersion) { + return false + } + case extensionCookie: + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // This extension has different formats in SH and HRR, accept either + // and let the handshake logic decide. See RFC 8446, Section 4.2.8. + if len(extData) == 2 { + if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { + return false + } + } else { + if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || + !readUint16LengthPrefixed(&extData, &m.serverShare.data) { + return false + } + } + case extensionPreSharedKey: + m.selectedIdentityPresent = true + if !extData.ReadUint16(&m.selectedIdentity) { + return false + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string + earlyData bool + + additionalExtensions []Extension +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeEncryptedExtensions) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + *m = encryptedExtensionsMsg{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var ext uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&ext) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch ext { + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionEarlyData: + m.earlyData = true + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type endOfEarlyDataMsg struct{} + +func (m *endOfEarlyDataMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeEndOfEarlyData + return x +} + +func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type keyUpdateMsg struct { + raw []byte + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +type newSessionTicketMsgTLS13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateRequestMsgTLS13 struct { + raw []byte + ocspStapling bool + scts bool + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateRequest) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // certificate_request_context (SHALL be zero length unless used for + // post-handshake authentication) + b.AddUint8(0) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.scts { + // RFC 8446, Section 4.4.2.1 makes no mention of + // signed_certificate_timestamp in CertificateRequest, but + // "Extensions in the Certificate message from the client MUST + // correspond to extensions in the CertificateRequest message + // from the server." and it appears in the table in Section 4.2. + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedSignatureAlgorithms) > 0 { + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + b.AddUint16(extensionCertificateAuthorities) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ca) + }) + } + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { + *m = certificateRequestMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSCT: + m.scts = true + case extensionSignatureAlgorithms: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionCertificateAuthorities: + var auths cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { + return false + } + for !auths.Empty() { + var ca []byte + if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { + return false + } + m.certificateAuthorities = append(m.certificateAuthorities, ca) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type certificateMsgTLS13 struct { + raw []byte + certificate Certificate + ocspStapling bool + scts bool +} + +func (m *certificateMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // certificate_request_context + + certificate := m.certificate + if !m.ocspStapling { + certificate.OCSPStaple = nil + } + if !m.scts { + certificate.SignedCertificateTimestamps = nil + } + marshalCertificate(b, certificate) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for i, cert := range certificate.Certificate { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if i > 0 { + // This library only supports OCSP and SCT for leaf certificates. + return + } + if certificate.OCSPStaple != nil { + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(certificate.OCSPStaple) + }) + }) + } + if certificate.SignedCertificateTimestamps != nil { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range certificate.SignedCertificateTimestamps { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + }) + } + }) +} + +func (m *certificateMsgTLS13) unmarshal(data []byte) bool { + *m = certificateMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !unmarshalCertificate(&s, &m.certificate) || + !s.Empty() { + return false + } + + m.scts = m.certificate.SignedCertificateTimestamps != nil + m.ocspStapling = m.certificate.OCSPStaple != nil + + return true +} + +func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + var extensions cryptobyte.String + if !readUint24LengthPrefixed(&certList, &cert) || + !certList.ReadUint16LengthPrefixed(&extensions) { + return false + } + certificate.Certificate = append(certificate.Certificate, cert) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + if len(certificate.Certificate) > 1 { + // This library only supports OCSP and SCT for leaf certificates. + continue + } + + switch extension { + case extensionStatusRequest: + var statusType uint8 + if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || + len(certificate.OCSPStaple) == 0 { + return false + } + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + certificate.SignedCertificateTimestamps = append( + certificate.SignedCertificateTimestamps, sct) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + } + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + response []byte +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateStatus) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.response) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var statusType uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&s, &m.response) || + len(m.response) == 0 || !s.Empty() { + return false + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeFinished) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.verifyData) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + return s.Skip(1) && + readUint24LengthPrefixed(&s, &m.verifyData) && + s.Empty() +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAlgorithm indicates whether this message includes a list of + // supported signature algorithms. This change was introduced with TLS 1.2. + hasSignatureAlgorithm bool + + certificateTypes []byte + supportedSignatureAlgorithms []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 4346, Section 7.4.4. + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAlgorithm { + length += 2 + 2*len(m.supportedSignatureAlgorithms) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAlgorithm { + n := len(m.supportedSignatureAlgorithms) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAlgo := range m.supportedSignatureAlgorithms { + y[0] = uint8(sigAlgo >> 8) + y[1] = uint8(sigAlgo) + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAlgorithm { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAlgos := sigAndHashLen / 2 + m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + return len(data) == 0 +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAlgorithm bool // format change introduced in TLS 1.2 + signatureAlgorithm SignatureScheme + signature []byte +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateVerify) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + b.AddUint16(uint16(m.signatureAlgorithm)) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.signature) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + if !s.Skip(4) { // message type and uint24 length field + return false + } + if m.hasSignatureAlgorithm { + if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { + return false + } + } + return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 5077, Section 3.3. + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go new file mode 100644 index 00000000..7d3557d7 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go @@ -0,0 +1,905 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "sync/atomic" + "time" +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + ctx context.Context + clientHello *clientHelloMsg + hello *serverHelloMsg + suite *cipherSuite + ecdheOk bool + ecSignOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + masterSecret []byte + cert *Certificate +} + +// serverHandshake performs a TLS handshake as a server. +func (c *Conn) serverHandshake(ctx context.Context) error { + c.setAlternativeRecordLayer() + + clientHello, err := c.readClientHello(ctx) + if err != nil { + return err + } + + if c.vers == VersionTLS13 { + hs := serverHandshakeStateTLS13{ + c: c, + ctx: ctx, + clientHello: clientHello, + } + return hs.handshake() + } else if c.extraConfig.usesAlternativeRecordLayer() { + // This should already have been caught by the check that the ClientHello doesn't + // offer any (supported) versions older than TLS 1.3. + // Check again to make sure we can't be tricked into using an older version. + c.sendAlert(alertProtocolVersion) + return errors.New("tls: negotiated TLS < 1.3 when using QUIC") + } + + hs := serverHandshakeState{ + c: c, + ctx: ctx, + clientHello: clientHello, + } + return hs.handshake() +} + +func (hs *serverHandshakeState) handshake() error { + c := hs.c + + if err := hs.processClientHello(); err != nil { + return err + } + + // For an overview of TLS handshaking, see RFC 5246, Section 7.3. + c.buffering = true + if hs.checkForResumption() { + // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true + if err := hs.doResumeHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(c.serverFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = false + if err := hs.readFinished(nil); err != nil { + return err + } + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + if err := hs.pickCipherSuite(); err != nil { + return err + } + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readFinished(c.clientFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = true + c.buffering = true + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(nil); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// readClientHello reads a ClientHello message and selects the protocol version. +func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { + msg, err := c.readHandshake() + if err != nil { + return nil, err + } + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return nil, unexpectedMessageError(clientHello, msg) + } + + var configForClient *config + originalConfig := c.config + if c.config.GetConfigForClient != nil { + chi := newClientHelloInfo(ctx, c, clientHello) + if cfc, err := c.config.GetConfigForClient(chi); err != nil { + c.sendAlert(alertInternalError) + return nil, err + } else if cfc != nil { + configForClient = fromConfig(cfc) + c.config = configForClient + } + } + c.ticketKeys = originalConfig.ticketKeys(configForClient) + + clientVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + clientVersions = supportedVersionsFromMax(clientHello.vers) + } + if c.extraConfig.usesAlternativeRecordLayer() { + // In QUIC, the client MUST NOT offer any old TLS versions. + // Here, we can only check that none of the other supported versions of this library + // (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here. + for _, ver := range clientVersions { + if ver == VersionTLS13 { + continue + } + for _, v := range supportedVersions { + if ver == v { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver) + } + } + } + // Make the config we're using allows us to use TLS 1.3. + if c.config.maxSupportedVersion() < VersionTLS13 { + c.sendAlert(alertInternalError) + return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + } + c.vers, ok = c.config.mutualVersion(clientVersions) + if !ok { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) + } + c.haveVers = true + c.in.version = c.vers + c.out.version = c.vers + + return clientHello, nil +} + +func (hs *serverHandshakeState) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client does not support uncompressed connections") + } + + hs.hello.random = make([]byte, 32) + serverRandom := hs.hello.random + // Downgrade protection canaries. See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion() + if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { + if c.vers == VersionTLS12 { + copy(serverRandom[24:], downgradeCanaryTLS12) + } else { + copy(serverRandom[24:], downgradeCanaryTLS11) + } + serverRandom = serverRandom[:24] + } + _, err := io.ReadFull(c.config.rand(), serverRandom) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + c.sendAlert(alertNoApplicationProtocol) + return err + } + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + + hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + if hs.clientHello.scts { + hs.hello.scts = hs.cert.SignedCertificateTimestamps + } + + hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) + + if hs.ecdheOk { + // Although omitting the ec_point_formats extension is permitted, some + // old OpenSSL version will refuse to handshake if not present. + // + // Per RFC 4492, section 5.1.2, implementations MUST support the + // uncompressed point format. See golang.org/issue/31943. + hs.hello.supportedPoints = []uint8{pointFormatUncompressed} + } + + if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { + switch priv.Public().(type) { + case *ecdsa.PublicKey: + hs.ecSignOk = true + case ed25519.PublicKey: + hs.ecSignOk = true + case *rsa.PublicKey: + hs.rsaSignOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) + } + } + if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { + switch priv.Public().(type) { + case *rsa.PublicKey: + hs.rsaDecryptOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) + } + } + + return nil +} + +// negotiateALPN picks a shared ALPN protocol that both sides support in server +// preference order. If ALPN is not configured or the peer doesn't support it, +// it returns "" and no error. +func negotiateALPN(serverProtos, clientProtos []string) (string, error) { + if len(serverProtos) == 0 || len(clientProtos) == 0 { + return "", nil + } + var http11fallback bool + for _, s := range serverProtos { + for _, c := range clientProtos { + if s == c { + return s, nil + } + if s == "h2" && c == "http/1.1" { + http11fallback = true + } + } + } + // As a special case, let http/1.1 clients connect to h2 servers as if they + // didn't support ALPN. We used not to enforce protocol overlap, so over + // time a number of HTTP servers were configured with only "h2", but + // expected to accept connections from "http/1.1" clients. See Issue 46310. + if http11fallback { + return "", nil + } + return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) +} + +// supportsECDHE returns whether ECDHE key exchanges can be used with this +// pre-TLS 1.3 client. +func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool { + supportsCurve := false + for _, curve := range supportedCurves { + if c.supportsCurve(curve) { + supportsCurve = true + break + } + } + + supportsPointFormat := false + for _, pointFormat := range supportedPoints { + if pointFormat == pointFormatUncompressed { + supportsPointFormat = true + break + } + } + + return supportsCurve && supportsPointFormat +} + +func (hs *serverHandshakeState) pickCipherSuite() error { + c := hs.c + + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + + configCipherSuites := c.config.cipherSuites() + preferenceList := make([]uint16, 0, len(configCipherSuites)) + for _, suiteID := range preferenceOrder { + for _, id := range configCipherSuites { + if id == suiteID { + preferenceList = append(preferenceList, id) + break + } + } + } + + hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk) + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. See RFC 7507. + if hs.clientHello.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return nil +} + +func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + if !hs.ecdheOk { + return false + } + if c.flags&suiteECSign != 0 { + if !hs.ecSignOk { + return false + } + } else if !hs.rsaSignOk { + return false + } + } else if !hs.rsaDecryptOk { + return false + } + if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true +} + +// checkForResumption reports whether we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if c.config.SessionTicketsDisabled { + return false + } + + plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) + if plaintext == nil { + return false + } + hs.sessionState = &sessionState{usedOldKey: usedOldKey} + ok := hs.sessionState.unmarshal(plaintext) + if !ok { + return false + } + + createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + return false + } + + // Never resume a session for a different TLS version. + if c.vers != hs.sessionState.vers { + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return false + } + + // Check that we also support the ciphersuite from the session. + hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, + c.config.cipherSuites(), hs.cipherSuiteOk) + if hs.suite == nil { + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + return false + } + + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + hs.hello.cipherSuite = hs.suite.id + c.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = hs.sessionState.usedOldKey + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := c.processCertsFromClient(Certificate{ + Certificate: hs.sessionState.certificates, + }); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + + return nil +} + +func (hs *serverHandshakeState) doFullHandshake() error { + c := hs.c + + if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled + hs.hello.cipherSuite = hs.suite.id + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + if c.config.ClientAuth == NoClientCert { + // No need to keep a full record of the handshake if client + // certificates won't be used. + hs.finishedHash.discardHandshakeBuffer() + } + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.response = hs.cert.OCSPStaple + hs.finishedHash.Write(certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } + } + + keyAgreement := hs.suite.ka(c.vers) + skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil { + hs.finishedHash.Write(skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } + } + + var certReq *certificateRequestMsg + if c.config.ClientAuth >= RequestClientCert { + // Request a client certificate + certReq = new(certificateRequestMsg) + certReq.certificateTypes = []byte{ + byte(certTypeRSASign), + byte(certTypeECDSASign), + } + if c.vers >= VersionTLS12 { + certReq.hasSignatureAlgorithm = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + hs.finishedHash.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + helloDone := new(serverHelloDoneMsg) + hs.finishedHash.Write(helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } + + if _, err := c.flush(); err != nil { + return err + } + + var pub crypto.PublicKey // public key for client auth, if any + + msg, err := c.readHandshake() + if err != nil { + return err + } + + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + if c.config.ClientAuth >= RequestClientCert { + certMsg, ok := msg.(*certificateMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(Certificate{ + Certificate: certMsg.certificates, + }); err != nil { + return err + } + if len(certMsg.certificates) != 0 { + pub = c.peerCertificates[0].PublicKey + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.finishedHash.Write(ckx.marshal()) + + preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.finishedHash.Write(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher interface{} + var clientHash, serverHash hash.Hash + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + + return nil +} + +func (hs *serverHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") + } + + hs.finishedHash.Write(clientFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + m := new(newSessionTicketMsg) + + createdAt := uint64(c.config.time().Unix()) + if hs.sessionState != nil { + // If this is re-wrapping an old key, then keep + // the original time it was created. + createdAt = hs.sessionState.createdAt + } + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + masterSecret: hs.masterSecret, + certificates: certsFromClient, + } + var err error + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + copy(out, finished.verifyData) + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (c *Conn) processCertsFromClient(certificate Certificate) error { + certificates := certificate.Certificate + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to verify client certificate: " + err.Error()) + } + + c.verifiedChains = chains + } + + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { + supportedVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + supportedVersions = supportedVersionsFromMax(clientHello.vers) + } + + return toClientHelloInfo(&clientHelloInfo{ + CipherSuites: clientHello.cipherSuites, + ServerName: clientHello.serverName, + SupportedCurves: clientHello.supportedCurves, + SupportedPoints: clientHello.supportedPoints, + SignatureSchemes: clientHello.supportedSignatureAlgorithms, + SupportedProtos: clientHello.alpnProtocols, + SupportedVersions: supportedVersions, + Conn: c.conn, + config: toConfig(c.config), + ctx: ctx, + }) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go new file mode 100644 index 00000000..0c200605 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go @@ -0,0 +1,896 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "hash" + "io" + "sync/atomic" + "time" +) + +// maxClientPSKIdentities is the number of client PSK identities the server will +// attempt to validate. It will ignore the rest not to let cheap ClientHello +// messages cause too much work in session ticket decryption attempts. +const maxClientPSKIdentities = 5 + +type serverHandshakeStateTLS13 struct { + c *Conn + ctx context.Context + clientHello *clientHelloMsg + hello *serverHelloMsg + alpnNegotiationErr error + encryptedExtensions *encryptedExtensionsMsg + sentDummyCCS bool + usingPSK bool + suite *cipherSuiteTLS13 + cert *Certificate + sigAlg SignatureScheme + earlySecret []byte + sharedKey []byte + handshakeSecret []byte + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 + transcript hash.Hash + clientFinished []byte +} + +func (hs *serverHandshakeStateTLS13) handshake() error { + c := hs.c + + // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. + if err := hs.processClientHello(); err != nil { + return err + } + if err := hs.checkForResumption(); err != nil { + return err + } + if err := hs.pickCertificate(); err != nil { + return err + } + c.buffering = true + if err := hs.sendServerParameters(); err != nil { + return err + } + if err := hs.sendServerCertificate(); err != nil { + return err + } + if err := hs.sendServerFinished(); err != nil { + return err + } + // Note that at this point we could start sending application data without + // waiting for the client's second flight, but the application might not + // expect the lack of replay protection of the ClientHello parameters. + if _, err := c.flush(); err != nil { + return err + } + if err := hs.readClientCertificate(); err != nil { + return err + } + if err := hs.readClientFinished(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *serverHandshakeStateTLS13) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.encryptedExtensions = new(encryptedExtensionsMsg) + + // TLS 1.3 froze the ServerHello.legacy_version field, and uses + // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. + hs.hello.vers = VersionTLS12 + hs.hello.supportedVersion = c.vers + + if len(hs.clientHello.supportedVersions) == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") + } + + // Abort if the client is doing a fallback and landing lower than what we + // support. See RFC 7507, which however does not specify the interaction + // with supported_versions. The only difference is that with + // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] + // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, + // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to + // TLS 1.2, because a TLS 1.3 server would abort here. The situation before + // supported_versions was not better because there was just no way to do a + // TLS 1.4 handshake without risking the server selecting TLS 1.3. + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // Use c.vers instead of max(supported_versions) because an attacker + // could defeat this by adding an arbitrary high version otherwise. + if c.vers < c.config.maxSupportedVersion() { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + if len(hs.clientHello.compressionMethods) != 1 || + hs.clientHello.compressionMethods[0] != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: TLS 1.3 client supports illegal compression methods") + } + + hs.hello.random = make([]byte, 32) + if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.compressionMethod = compressionNone + + if hs.suite == nil { + var preferenceList []uint16 + for _, suiteID := range c.config.CipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + preferenceList = append(preferenceList, suiteID) + break + } + } + } + if len(preferenceList) == 0 { + preferenceList = defaultCipherSuitesTLS13 + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = defaultCipherSuitesTLS13NoAES + } + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) + if hs.suite != nil { + break + } + } + } + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + hs.hello.cipherSuite = hs.suite.id + hs.transcript = hs.suite.hash.New() + + // Pick the ECDHE group in server preference order, but give priority to + // groups with a key share, to avoid a HelloRetryRequest round-trip. + var selectedGroup CurveID + var clientKeyShare *keyShare +GroupSelection: + for _, preferredGroup := range c.config.curvePreferences() { + for _, ks := range hs.clientHello.keyShares { + if ks.group == preferredGroup { + selectedGroup = ks.group + clientKeyShare = &ks + break GroupSelection + } + } + if selectedGroup != 0 { + continue + } + for _, group := range hs.clientHello.supportedCurves { + if group == preferredGroup { + selectedGroup = group + break + } + } + } + if selectedGroup == 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no ECDHE curve supported by both client and server") + } + if clientKeyShare == nil { + if err := hs.doHelloRetryRequest(selectedGroup); err != nil { + return err + } + clientKeyShare = &hs.clientHello.keyShares[0] + } + + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} + hs.sharedKey = params.SharedKey(clientKeyShare.data) + if hs.sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid client key share") + } + + c.serverName = hs.clientHello.serverName + + if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil { + c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions) + } + + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + hs.alpnNegotiationErr = err + } + hs.encryptedExtensions.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + + return nil +} + +func (hs *serverHandshakeStateTLS13) checkForResumption() error { + c := hs.c + + if c.config.SessionTicketsDisabled { + return nil + } + + modeOK := false + for _, mode := range hs.clientHello.pskModes { + if mode == pskModeDHE { + modeOK = true + break + } + } + if !modeOK { + return nil + } + + if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid or missing PSK binders") + } + if len(hs.clientHello.pskIdentities) == 0 { + return nil + } + + for i, identity := range hs.clientHello.pskIdentities { + if i >= maxClientPSKIdentities { + break + } + + plaintext, _ := c.decryptTicket(identity.label) + if plaintext == nil { + continue + } + sessionState := new(sessionStateTLS13) + if ok := sessionState.unmarshal(plaintext); !ok { + continue + } + + if hs.clientHello.earlyData { + if sessionState.maxEarlyData == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: client sent unexpected early data") + } + + if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol && + c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 && + c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) { + hs.encryptedExtensions.earlyData = true + c.used0RTT = true + } + } + + createdAt := time.Unix(int64(sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + continue + } + + // We don't check the obfuscated ticket age because it's affected by + // clock skew and it's only a freshness signal useful for shrinking the + // window for replay attacks, which don't affect us as we don't do 0-RTT. + + pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) + if pskSuite == nil || pskSuite.hash != hs.suite.hash { + continue + } + + // PSK connections don't re-establish client certificates, but carry + // them over in the session ticket. Ensure the presence of client certs + // in the ticket is consistent with the configured requirements. + sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + continue + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + continue + } + + psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", + nil, hs.suite.hash.Size()) + hs.earlySecret = hs.suite.extract(psk, nil) + binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) + // Clone the transcript in case a HelloRetryRequest was recorded. + transcript := cloneHash(hs.transcript, hs.suite.hash) + if transcript == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: failed to clone hash") + } + transcript.Write(hs.clientHello.marshalWithoutBinders()) + pskBinder := hs.suite.finishedHash(binderKey, transcript) + if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid PSK binder") + } + + c.didResume = true + if err := c.processCertsFromClient(sessionState.certificate); err != nil { + return err + } + + h := cloneHash(hs.transcript, hs.suite.hash) + h.Write(hs.clientHello.marshal()) + if hs.encryptedExtensions.earlyData { + clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) + c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + + hs.hello.selectedIdentityPresent = true + hs.hello.selectedIdentity = uint16(i) + hs.usingPSK = true + return nil + } + + return nil +} + +// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// interfaces implemented by standard library hashes to clone the state of in +// to a new instance of h. It returns nil if the operation fails. +func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { + // Recreate the interface to avoid importing encoding. + type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) + UnmarshalBinary(data []byte) error + } + marshaler, ok := in.(binaryMarshaler) + if !ok { + return nil + } + state, err := marshaler.MarshalBinary() + if err != nil { + return nil + } + out := h.New() + unmarshaler, ok := out.(binaryMarshaler) + if !ok { + return nil + } + if err := unmarshaler.UnmarshalBinary(state); err != nil { + return nil + } + return out +} + +func (hs *serverHandshakeStateTLS13) pickCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. + if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { + return c.sendAlert(alertMissingExtension) + } + + certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) + if err != nil { + // getCertificate returned a certificate that is unsupported or + // incompatible with the client's signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + hs.cert = certificate + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. See RFC 8446, Section 4.4.1. + hs.transcript.Write(hs.clientHello.marshal()) + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + + helloRetryRequest := &serverHelloMsg{ + vers: hs.hello.vers, + random: helloRetryRequestRandom, + sessionId: hs.hello.sessionId, + cipherSuite: hs.hello.cipherSuite, + compressionMethod: hs.hello.compressionMethod, + supportedVersion: hs.hello.supportedVersion, + selectedGroup: selectedGroup, + } + + hs.transcript.Write(helloRetryRequest.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientHello, msg) + } + + if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client sent invalid key share in second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client indicated early data in second ClientHello") + } + + if illegalClientHelloChange(clientHello, hs.clientHello) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client illegally modified second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client offered 0-RTT data in second ClientHello") + } + + hs.clientHello = clientHello + return nil +} + +// illegalClientHelloChange reports whether the two ClientHello messages are +// different, with the exception of the changes allowed before and after a +// HelloRetryRequest. See RFC 8446, Section 4.1.2. +func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { + if len(ch.supportedVersions) != len(ch1.supportedVersions) || + len(ch.cipherSuites) != len(ch1.cipherSuites) || + len(ch.supportedCurves) != len(ch1.supportedCurves) || + len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || + len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || + len(ch.alpnProtocols) != len(ch1.alpnProtocols) { + return true + } + for i := range ch.supportedVersions { + if ch.supportedVersions[i] != ch1.supportedVersions[i] { + return true + } + } + for i := range ch.cipherSuites { + if ch.cipherSuites[i] != ch1.cipherSuites[i] { + return true + } + } + for i := range ch.supportedCurves { + if ch.supportedCurves[i] != ch1.supportedCurves[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithms { + if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithmsCert { + if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { + return true + } + } + for i := range ch.alpnProtocols { + if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { + return true + } + } + return ch.vers != ch1.vers || + !bytes.Equal(ch.random, ch1.random) || + !bytes.Equal(ch.sessionId, ch1.sessionId) || + !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || + ch.serverName != ch1.serverName || + ch.ocspStapling != ch1.ocspStapling || + !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || + ch.ticketSupported != ch1.ticketSupported || + !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || + ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || + !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || + ch.scts != ch1.scts || + !bytes.Equal(ch.cookie, ch1.cookie) || + !bytes.Equal(ch.pskModes, ch1.pskModes) +} + +func (hs *serverHandshakeStateTLS13) sendServerParameters() error { + c := hs.c + + hs.transcript.Write(hs.clientHello.marshal()) + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + earlySecret := hs.earlySecret + if earlySecret == nil { + earlySecret = hs.suite.extract(nil, nil) + } + hs.handshakeSecret = hs.suite.extract(hs.sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.in.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if hs.alpnNegotiationErr != nil { + c.sendAlert(alertNoApplicationProtocol) + return hs.alpnNegotiationErr + } + if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil { + hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) + } + + hs.transcript.Write(hs.encryptedExtensions.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) requestClientCert() bool { + return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK +} + +func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + if hs.requestClientCert() { + // Request a client certificate + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + + hs.transcript.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *hs.cert + certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + certVerifyMsg.signatureAlgorithm = hs.sigAlg + + sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + public := hs.cert.PrivateKey.(crypto.Signer).Public() + if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && + rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS + c.sendAlert(alertHandshakeFailure) + } else { + c.sendAlert(alertInternalError) + } + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) sendServerFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + // Derive secrets that take context through the server Finished. + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + // If we did not request client certificates, at this point we can + // precompute the client finished and roll the transcript forward to send + // session tickets in our first flight. + if !hs.requestClientCert() { + if err := hs.sendSessionTickets(); err != nil { + return err + } + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { + if hs.c.config.SessionTicketsDisabled { + return false + } + + // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. + for _, pskMode := range hs.clientHello.pskModes { + if pskMode == pskModeDHE { + return true + } + } + return false +} + +func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { + c := hs.c + + hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + finishedMsg := &finishedMsg{ + verifyData: hs.clientFinished, + } + hs.transcript.Write(finishedMsg.marshal()) + + if !hs.shouldSendSessionTickets() { + return nil + } + + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + + // Don't send session tickets when the alternative record layer is set. + // Instead, save the resumption secret on the Conn. + // Session tickets can then be generated by calling Conn.GetSessionTicket(). + if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil { + return nil + } + + m, err := hs.c.getSessionTicketMsg(nil) + if err != nil { + return err + } + + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientCertificate() error { + c := hs.c + + if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + // If we requested a client certificate, then the client must send a + // certificate message. If it's empty, no CertificateVerify is sent. + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.transcript.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(certMsg.certificate); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if len(certMsg.certificate.Certificate) != 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + } + + // If we waited until the client certificates to send session tickets, we + // are ready to do it now. + if err := hs.sendSessionTickets(); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + if !hmac.Equal(hs.clientFinished, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid client finished hash") + } + + c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.in.setTrafficSecret(hs.suite, hs.trafficSecret) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go b/vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go new file mode 100644 index 00000000..453a8dcf --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go @@ -0,0 +1,357 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "errors" + "fmt" + "io" +) + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct{} + +func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext := ckx.ciphertext[2:] + + priv, ok := cert.PrivateKey.(crypto.Decrypter) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") + } + // Perform constant time RSA PKCS #1 v1.5 decryption + preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) + if err != nil { + return nil, err + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + return preMasterSecret, nil +} + +func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + preMasterSecret[0] = byte(clientHello.vers >> 8) + preMasterSecret[1] = byte(clientHello.vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") + } + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// using the given hash function (for >= TLS 1.2) or using a default based on +// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't +// do pre-hashing, it returns the concatenation of the slices. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { + if sigType == signatureEd25519 { + var signed []byte + for _, slice := range slices { + signed = append(signed, slice...) + } + return signed + } + if version >= VersionTLS12 { + h := hashFunc.New() + for _, slice := range slices { + h.Write(slice) + } + digest := h.Sum(nil) + return digest + } + if sigType == signatureECDSA { + return sha1Hash(slices) + } + return md5SHA1Hash(slices) +} + +// ecdheKeyAgreement implements a TLS key agreement where the server +// generates an ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// be ECDSA, Ed25519 or RSA. +type ecdheKeyAgreement struct { + version uint16 + isRSA bool + params ecdheParameters + + // ckx and preMasterSecret are generated in processServerKeyExchange + // and returned in generateClientKeyExchange. + ckx *clientKeyExchangeMsg + preMasterSecret []byte +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + var curveID CurveID + for _, c := range clientHello.supportedCurves { + if config.supportsCurve(c) { + curveID = c + break + } + } + + if curveID == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, err + } + ka.params = params + + // See RFC 4492, Section 5.4. + ecdhePublic := params.PublicKey() + serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHEParams[0] = 3 // named curve + serverECDHEParams[1] = byte(curveID >> 8) + serverECDHEParams[2] = byte(curveID) + serverECDHEParams[3] = byte(len(ecdhePublic)) + copy(serverECDHEParams[4:], ecdhePublic) + + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) + } + + var signatureAlgorithm SignatureScheme + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) + if err != nil { + return nil, err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return nil, err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) + if err != nil { + return nil, err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) + + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := priv.Sign(config.rand(), signed, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } + + skx := new(serverKeyExchangeMsg) + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) + copy(skx.key, serverECDHEParams) + k := skx.key[len(serverECDHEParams):] + if ka.version >= VersionTLS12 { + k[0] = byte(signatureAlgorithm >> 8) + k[1] = byte(signatureAlgorithm) + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + + return skx, nil +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + + preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) + if preMasterSecret == nil { + return nil, errClientKeyExchange + } + + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + serverECDHEParams := skx.key[:4+publicLen] + publicKey := serverECDHEParams[4:] + + sig := skx.key[4+publicLen:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return errors.New("tls: server selected unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return err + } + ka.params = params + + ka.preMasterSecret = params.SharedKey(publicKey) + if ka.preMasterSecret == nil { + return errServerKeyExchange + } + + ourPublicKey := params.PublicKey() + ka.ckx = new(clientKeyExchangeMsg) + ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) + ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) + copy(ka.ckx.ciphertext[1:], ourPublicKey) + + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) + if err != nil { + return err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + return nil +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.ckx == nil { + return nil, nil, errors.New("tls: missing ServerKeyExchange message") + } + + return ka.preMasterSecret, ka.ckx, nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/key_schedule.go b/vendor/github.com/marten-seemann/qtls-go1-17/key_schedule.go new file mode 100644 index 00000000..da13904a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/key_schedule.go @@ -0,0 +1,199 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto/elliptic" + "crypto/hmac" + "errors" + "hash" + "io" + "math/big" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +// This file contains the functions necessary to compute the TLS 1.3 key +// schedule. See RFC 8446, Section 7. + +const ( + resumptionBinderLabel = "res binder" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + exporterLabel = "exp master" + resumptionLabel = "res master" + trafficUpdateLabel = "traffic upd" +) + +// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("tls: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { + if transcript == nil { + transcript = c.hash.New() + } + return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) +} + +// extract implements HKDF-Extract with the cipher suite hash. +func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, c.hash.Size()) + } + return hkdf.Extract(c.hash.New, newSecret, currentSecret) +} + +// nextTrafficSecret generates the next traffic secret, given the current one, +// according to RFC 8446, Section 7.2. +func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { + return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) +} + +// trafficKey generates traffic keys according to RFC 8446, Section 7.3. +func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { + key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) + iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + return +} + +// finishedHash generates the Finished verify_data or PskBinderEntry according +// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey +// selection. +func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { + finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) + verifyData := hmac.New(c.hash.New, finishedKey) + verifyData.Write(transcript.Sum(nil)) + return verifyData.Sum(nil) +} + +// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to +// RFC 8446, Section 7.5. +func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) + return func(label string, context []byte, length int) ([]byte, error) { + secret := c.deriveSecret(expMasterSecret, label, nil) + h := c.hash.New() + h.Write(context) + return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + } +} + +// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// according to RFC 8446, Section 4.2.8.2. +type ecdheParameters interface { + CurveID() CurveID + PublicKey() []byte + SharedKey(peerPublicKey []byte) []byte +} + +func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { + if curveID == X25519 { + privateKey := make([]byte, curve25519.ScalarSize) + if _, err := io.ReadFull(rand, privateKey); err != nil { + return nil, err + } + publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) + if err != nil { + return nil, err + } + return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, errors.New("tls: internal error: unsupported curve") + } + + p := &nistParameters{curveID: curveID} + var err error + p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) + if err != nil { + return nil, err + } + return p, nil +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } +} + +type nistParameters struct { + privateKey []byte + x, y *big.Int // public key + curveID CurveID +} + +func (p *nistParameters) CurveID() CurveID { + return p.curveID +} + +func (p *nistParameters) PublicKey() []byte { + curve, _ := curveForCurveID(p.curveID) + return elliptic.Marshal(curve, p.x, p.y) +} + +func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { + curve, _ := curveForCurveID(p.curveID) + // Unmarshal also checks whether the given point is on the curve. + x, y := elliptic.Unmarshal(curve, peerPublicKey) + if x == nil { + return nil + } + + xShared, _ := curve.ScalarMult(x, y, p.privateKey) + sharedKey := make([]byte, (curve.Params().BitSize+7)/8) + return xShared.FillBytes(sharedKey) +} + +type x25519Parameters struct { + privateKey []byte + publicKey []byte +} + +func (p *x25519Parameters) CurveID() CurveID { + return X25519 +} + +func (p *x25519Parameters) PublicKey() []byte { + return p.publicKey[:] +} + +func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { + sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) + if err != nil { + return nil + } + return sharedKey +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/prf.go b/vendor/github.com/marten-seemann/qtls-go1-17/prf.go new file mode 100644 index 00000000..9eb0221a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/prf.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +const ( + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") + +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + switch version { + case VersionTLS10, VersionTLS11: + return prf10, crypto.Hash(0) + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + default: + panic("unknown version") + } +} + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See RFC 5246, Section 8.1. +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { + seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) + return masterSecret +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, Section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) + seed = append(seed, serverRandom...) + seed = append(seed, clientRandom...) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version >= VersionTLS12 { + buffer = []byte{} + } + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// hashForClientCertificate returns the handshake messages so far, pre-hashed if +// necessary, suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { + if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { + panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") + } + + if sigType == signatureEd25519 { + return h.buffer + } + + if h.version >= VersionTLS12 { + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil) + } + + if sigType == signatureECDSA { + return h.server.Sum(nil) + } + + return h.Sum() +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} + +// noExportedKeyingMaterial is used as a value of +// ConnectionState.ekm when renegotiation is enabled and thus +// we wish to fail all key-material export requests. +func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") +} + +// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. +func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { + return func(label string, context []byte, length int) ([]byte, error) { + switch label { + case "client finished", "server finished", "master secret", "key expansion": + // These values are reserved and may not be used. + return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) + } + + seedLen := len(serverRandom) + len(clientRandom) + if context != nil { + seedLen += 2 + len(context) + } + seed := make([]byte, 0, seedLen) + + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + if context != nil { + if len(context) >= 1<<16 { + return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") + } + seed = append(seed, byte(len(context)>>8), byte(len(context))) + seed = append(seed, context...) + } + + keyMaterial := make([]byte, length) + prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) + return keyMaterial, nil + } +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/ticket.go b/vendor/github.com/marten-seemann/qtls-go1-17/ticket.go new file mode 100644 index 00000000..81e8a52e --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/ticket.go @@ -0,0 +1,274 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/binary" + "errors" + "io" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + createdAt uint64 + masterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^24-1> } Certificate; + certificates [][]byte // Certificate certificate_list<0..2^24-1>; + + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + usedOldKey bool +} + +func (m *sessionState) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(m.vers) + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.masterSecret) + }) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for _, cert := range m.certificates { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + } + }) + return b.BytesOrPanic() +} + +func (m *sessionState) unmarshal(data []byte) bool { + *m = sessionState{usedOldKey: m.usedOldKey} + s := cryptobyte.String(data) + if ok := s.ReadUint16(&m.vers) && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint16LengthPrefixed(&s, &m.masterSecret) && + len(m.masterSecret) != 0; !ok { + return false + } + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + if !readUint24LengthPrefixed(&certList, &cert) { + return false + } + m.certificates = append(m.certificates, cert) + } + return s.Empty() +} + +// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first +// version (revision = 0) doesn't carry any of the information needed for 0-RTT +// validation and the nonce is always empty. +// version (revision = 1) carries the max_early_data_size sent in the ticket. +// version (revision = 2) carries the ALPN sent in the ticket. +type sessionStateTLS13 struct { + // uint8 version = 0x0304; + // uint8 revision = 2; + cipherSuite uint16 + createdAt uint64 + resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; + certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; + maxEarlyData uint32 + alpn string + + appData []byte +} + +func (m *sessionStateTLS13) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(VersionTLS13) + b.AddUint8(2) // revision + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.resumptionSecret) + }) + marshalCertificate(&b, m.certificate) + b.AddUint32(m.maxEarlyData) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpn)) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.appData) + }) + return b.BytesOrPanic() +} + +func (m *sessionStateTLS13) unmarshal(data []byte) bool { + *m = sessionStateTLS13{} + s := cryptobyte.String(data) + var version uint16 + var revision uint8 + var alpn []byte + ret := s.ReadUint16(&version) && + version == VersionTLS13 && + s.ReadUint8(&revision) && + revision == 2 && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint8LengthPrefixed(&s, &m.resumptionSecret) && + len(m.resumptionSecret) != 0 && + unmarshalCertificate(&s, &m.certificate) && + s.ReadUint32(&m.maxEarlyData) && + readUint8LengthPrefixed(&s, &alpn) && + readUint16LengthPrefixed(&s, &m.appData) && + s.Empty() + m.alpn = string(alpn) + return ret +} + +func (c *Conn) encryptTicket(state []byte) ([]byte, error) { + if len(c.ticketKeys) == 0 { + return nil, errors.New("tls: internal error: session ticket keys unavailable") + } + + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + key := c.ticketKeys[0] + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { + if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { + return nil, false + } + + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] + + keyIndex := -1 + for i, candidateKey := range c.ticketKeys { + if bytes.Equal(keyName, candidateKey.keyName[:]) { + keyIndex = i + break + } + } + if keyIndex == -1 { + return nil, false + } + key := &c.ticketKeys[keyIndex] + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, false + } + plaintext = make([]byte, len(ciphertext)) + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + return plaintext, keyIndex > 0 +} + +func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) { + m := new(newSessionTicketMsgTLS13) + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionStateTLS13{ + cipherSuite: c.cipherSuite, + createdAt: uint64(c.config.time().Unix()), + resumptionSecret: c.resumptionSecret, + certificate: Certificate{ + Certificate: certsFromClient, + OCSPStaple: c.ocspResponse, + SignedCertificateTimestamps: c.scts, + }, + appData: appData, + alpn: c.clientProtocol, + } + if c.extraConfig != nil { + state.maxEarlyData = c.extraConfig.MaxEarlyData + } + var err error + m.label, err = c.encryptTicket(state.marshal()) + if err != nil { + return nil, err + } + m.lifetime = uint32(maxSessionTicketLifetime / time.Second) + + // ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1 + // The value is not stored anywhere; we never need to check the ticket age + // because 0-RTT is not supported. + ageAdd := make([]byte, 4) + _, err = c.config.rand().Read(ageAdd) + if err != nil { + return nil, err + } + m.ageAdd = binary.LittleEndian.Uint32(ageAdd) + + // ticket_nonce, which must be unique per connection, is always left at + // zero because we only ever send one ticket per connection. + + if c.extraConfig != nil { + m.maxEarlyData = c.extraConfig.MaxEarlyData + } + return m, nil +} + +// GetSessionTicket generates a new session ticket. +// It should only be called after the handshake completes. +// It can only be used for servers, and only if the alternative record layer is set. +// The ticket may be nil if config.SessionTicketsDisabled is set, +// or if the client isn't able to receive session tickets. +func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { + if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.") + } + if c.config.SessionTicketsDisabled { + return nil, nil + } + + m, err := c.getSessionTicketMsg(appData) + if err != nil { + return nil, err + } + return m.marshal(), nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/tls.go b/vendor/github.com/marten-seemann/qtls-go1-17/tls.go new file mode 100644 index 00000000..42207c23 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/tls.go @@ -0,0 +1,362 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// package qtls partially implements TLS 1.2, as specified in RFC 5246, +// and TLS 1.3, as specified in RFC 8446. +package qtls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "os" + "strings" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + } + c.handshakeFn = c.serverHandshake + return c +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + isClient: true, + } + c.handshakeFn = c.clientHandshake + return c +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type listener struct { + net.Listener + config *Config + extraConfig *ExtraConfig +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection is of type *Conn. +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return Server(c, l.config, l.extraConfig), nil +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener { + l := new(listener) + l.Listener = inner + l.config = config + l.extraConfig = extraConfig + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) { + if config == nil || len(config.Certificates) == 0 && + config.GetCertificate == nil && config.GetConfigForClient == nil { + return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config, extraConfig), nil +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +// +// DialWithDialer uses context.Background internally; to specify the context, +// use Dialer.DialContext with NetDialer set to the desired dialer. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return dial(context.Background(), dialer, network, addr, config, extraConfig) +} + +func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + if netDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) + defer cancel() + } + + if !netDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) + defer cancel() + } + + rawConn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = defaultConfig() + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config, extraConfig) + if err := conn.HandshakeContext(ctx); err != nil { + rawConn.Close() + return nil, err + } + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig) +} + +// Dialer dials TLS connections given a configuration and a Dialer for the +// underlying connection. +type Dialer struct { + // NetDialer is the optional dialer to use for the TLS connections' + // underlying TCP connections. + // A nil NetDialer is equivalent to the net.Dialer zero value. + NetDialer *net.Dialer + + // Config is the TLS configuration to use for new connections. + // A nil configuration is equivalent to the zero + // configuration; see the documentation of Config for the + // defaults. + Config *Config + + ExtraConfig *ExtraConfig +} + +// Dial connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The returned Conn, if any, will always be of type *Conn. +// +// Dial uses context.Background internally; to specify the context, +// use DialContext. +func (d *Dialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func (d *Dialer) netDialer() *net.Dialer { + if d.NetDialer != nil { + return d.NetDialer + } + return new(net.Dialer) +} + +// DialContext connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig) + if err != nil { + // Don't return c (a typed nil) in an interface. + return nil, err + } + return c, nil +} + +// LoadX509KeyPair reads and parses a public/private key pair from a pair +// of files. The files must contain PEM encoded data. The certificate file +// may contain intermediate certificates following the leaf certificate to +// form a certificate chain. On successful return, Certificate.Leaf will +// be nil because the parsed form of the certificate is not retained. +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := os.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := os.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. On successful return, Certificate.Leaf will be nil because +// the parsed form of the certificate is not retained. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + fail := func(err error) (Certificate, error) { return Certificate{}, err } + + var cert Certificate + var skippedBlockTypes []string + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(cert.Certificate) == 0 { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in certificate input")) + } + if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { + return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) + } + return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + + skippedBlockTypes = skippedBlockTypes[:0] + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in key input")) + } + if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { + return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) + } + return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fail(err) + } + + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return fail(err) + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.N.Cmp(priv.N) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case ed25519.PublicKey: + priv, ok := cert.PrivateKey.(ed25519.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { + return fail(errors.New("tls: private key does not match public key")) + } + default: + return fail(errors.New("tls: unknown public key algorithm")) + } + + return cert, nil +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse private key") +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/unsafe.go b/vendor/github.com/marten-seemann/qtls-go1-17/unsafe.go new file mode 100644 index 00000000..55fa01b3 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-17/unsafe.go @@ -0,0 +1,96 @@ +package qtls + +import ( + "crypto/tls" + "reflect" + "unsafe" +) + +func init() { + if !structsEqual(&tls.ConnectionState{}, &connectionState{}) { + panic("qtls.ConnectionState doesn't match") + } + if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) { + panic("qtls.ClientSessionState doesn't match") + } + if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) { + panic("qtls.CertificateRequestInfo doesn't match") + } + if !structsEqual(&tls.Config{}, &config{}) { + panic("qtls.Config doesn't match") + } + if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { + panic("qtls.ClientHelloInfo doesn't match") + } +} + +func toConnectionState(c connectionState) ConnectionState { + return *(*ConnectionState)(unsafe.Pointer(&c)) +} + +func toClientSessionState(s *clientSessionState) *ClientSessionState { + return (*ClientSessionState)(unsafe.Pointer(s)) +} + +func fromClientSessionState(s *ClientSessionState) *clientSessionState { + return (*clientSessionState)(unsafe.Pointer(s)) +} + +func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo { + return (*CertificateRequestInfo)(unsafe.Pointer(i)) +} + +func toConfig(c *config) *Config { + return (*Config)(unsafe.Pointer(c)) +} + +func fromConfig(c *Config) *config { + return (*config)(unsafe.Pointer(c)) +} + +func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo { + return (*ClientHelloInfo)(unsafe.Pointer(chi)) +} + +func structsEqual(a, b interface{}) bool { + return compare(reflect.ValueOf(a), reflect.ValueOf(b)) +} + +func compare(a, b reflect.Value) bool { + sa := a.Elem() + sb := b.Elem() + if sa.NumField() != sb.NumField() { + return false + } + for i := 0; i < sa.NumField(); i++ { + fa := sa.Type().Field(i) + fb := sb.Type().Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + if fa.Type.Kind() != fb.Type.Kind() { + return false + } + if fa.Type.Kind() == reflect.Slice { + if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) { + return false + } + continue + } + return false + } + } + return true +} + +func compareStruct(a, b reflect.Type) bool { + if a.NumField() != b.NumField() { + return false + } + for i := 0; i < a.NumField(); i++ { + fa := a.Field(i) + fb := b.Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + return false + } + } + return true +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/LICENSE b/vendor/github.com/marten-seemann/qtls-go1-18/LICENSE new file mode 100644 index 00000000..6a66aea5 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/README.md b/vendor/github.com/marten-seemann/qtls-go1-18/README.md new file mode 100644 index 00000000..3e902212 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/README.md @@ -0,0 +1,6 @@ +# qtls + +[![Go Reference](https://pkg.go.dev/badge/github.com/marten-seemann/qtls-go1-17.svg)](https://pkg.go.dev/github.com/marten-seemann/qtls-go1-17) +[![.github/workflows/go-test.yml](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml/badge.svg)](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml) + +This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/alert.go b/vendor/github.com/marten-seemann/qtls-go1-18/alert.go new file mode 100644 index 00000000..3feac79b --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/alert.go @@ -0,0 +1,102 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import "strconv" + +type alert uint8 + +// Alert is a TLS alert +type Alert = alert + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertExportRestriction alert = 60 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertInappropriateFallback alert = 86 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 + alertMissingExtension alert = 109 + alertUnsupportedExtension alert = 110 + alertCertificateUnobtainable alert = 111 + alertUnrecognizedName alert = 112 + alertBadCertificateStatusResponse alert = 113 + alertBadCertificateHashValue alert = 114 + alertUnknownPSKIdentity alert = 115 + alertCertificateRequired alert = 116 + alertNoApplicationProtocol alert = 120 +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertExportRestriction: "export restriction", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertInappropriateFallback: "inappropriate fallback", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", + alertMissingExtension: "missing extension", + alertUnsupportedExtension: "unsupported extension", + alertCertificateUnobtainable: "certificate unobtainable", + alertUnrecognizedName: "unrecognized name", + alertBadCertificateStatusResponse: "bad certificate status response", + alertBadCertificateHashValue: "bad certificate hash value", + alertUnknownPSKIdentity: "unknown PSK identity", + alertCertificateRequired: "certificate required", + alertNoApplicationProtocol: "no application protocol", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return "tls: " + s + } + return "tls: alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/auth.go b/vendor/github.com/marten-seemann/qtls-go1-18/auth.go new file mode 100644 index 00000000..1ef675fd --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/auth.go @@ -0,0 +1,289 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + "hash" + "io" +) + +// verifyHandshakeSignature verifies a signature against pre-hashed +// (if required) handshake contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) + } + if !ecdsa.VerifyASN1(pubKey, signed, sig) { + return errors.New("ECDSA verification failure") + } + case signatureEd25519: + pubKey, ok := pubkey.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) + } + if !ed25519.Verify(pubKey, signed, sig) { + return errors.New("Ed25519 verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { + return err + } + case signatureRSAPSS: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { + return err + } + default: + return errors.New("internal error: unknown signature type") + } + return nil +} + +const ( + serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" + clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" +) + +var signaturePadding = []byte{ + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, +} + +// signedMessage returns the pre-hashed (if necessary) message to be signed by +// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. +func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { + if sigHash == directSigning { + b := &bytes.Buffer{} + b.Write(signaturePadding) + io.WriteString(b, context) + b.Write(transcript.Sum(nil)) + return b.Bytes() + } + h := sigHash.New() + h.Write(signaturePadding) + io.WriteString(h, context) + h.Write(transcript.Sum(nil)) + return h.Sum(nil) +} + +// typeAndHashFromSignatureScheme returns the corresponding signature type and +// crypto.Hash for a given TLS SignatureScheme. +func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { + switch signatureAlgorithm { + case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: + sigType = signaturePKCS1v15 + case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: + sigType = signatureRSAPSS + case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: + sigType = signatureECDSA + case Ed25519: + sigType = signatureEd25519 + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + switch signatureAlgorithm { + case PKCS1WithSHA1, ECDSAWithSHA1: + hash = crypto.SHA1 + case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: + hash = crypto.SHA256 + case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: + hash = crypto.SHA384 + case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: + hash = crypto.SHA512 + case Ed25519: + hash = directSigning + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + return sigType, hash, nil +} + +// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for +// a given public key used with TLS 1.0 and 1.1, before the introduction of +// signature algorithm negotiation. +func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { + switch pub.(type) { + case *rsa.PublicKey: + return signaturePKCS1v15, crypto.MD5SHA1, nil + case *ecdsa.PublicKey: + return signatureECDSA, crypto.SHA1, nil + case ed25519.PublicKey: + // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, + // but it requires holding on to a handshake transcript to do a + // full signature, and not even OpenSSL bothers with the + // complexity, so we can't even test it properly. + return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") + default: + return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) + } +} + +var rsaSignatureSchemes = []struct { + scheme SignatureScheme + minModulusBytes int + maxVersion uint16 +}{ + // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires + // emLen >= hLen + sLen + 2 + {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, + // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires + // emLen >= len(prefix) + hLen + 11 + // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. + {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, + {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, + {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, + {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, +} + +// signatureSchemesForCertificate returns the list of supported SignatureSchemes +// for a given certificate, based on the public key and the protocol version, +// and optionally filtered by its explicit SupportedSignatureAlgorithms. +// +// This function must be kept in sync with supportedSignatureAlgorithms. +func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil + } + + var sigAlgs []SignatureScheme + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + if version != VersionTLS13 { + // In TLS 1.2 and earlier, ECDSA algorithms are not + // constrained to a single curve. + sigAlgs = []SignatureScheme{ + ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + ECDSAWithSHA1, + } + break + } + switch pub.Curve { + case elliptic.P256(): + sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return nil + } + case *rsa.PublicKey: + size := pub.Size() + sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) + for _, candidate := range rsaSignatureSchemes { + if size >= candidate.minModulusBytes && version <= candidate.maxVersion { + sigAlgs = append(sigAlgs, candidate.scheme) + } + } + case ed25519.PublicKey: + sigAlgs = []SignatureScheme{Ed25519} + default: + return nil + } + + if cert.SupportedSignatureAlgorithms != nil { + var filteredSigAlgs []SignatureScheme + for _, sigAlg := range sigAlgs { + if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { + filteredSigAlgs = append(filteredSigAlgs, sigAlg) + } + } + return filteredSigAlgs + } + return sigAlgs +} + +// selectSignatureScheme picks a SignatureScheme from the peer's preference list +// that works with the selected certificate. It's only called for protocol +// versions that support signature algorithms, so TLS 1.2 and 1.3. +func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { + supportedAlgs := signatureSchemesForCertificate(vers, c) + if len(supportedAlgs) == 0 { + return 0, unsupportedCertificateError(c) + } + if len(peerAlgs) == 0 && vers == VersionTLS12 { + // For TLS 1.2, if the client didn't send signature_algorithms then we + // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. + peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} + } + // Pick signature scheme in the peer's preference order, as our + // preference order is not configurable. + for _, preferredAlg := range peerAlgs { + if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { + return preferredAlg, nil + } + } + return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") +} + +// unsupportedCertificateError returns a helpful error for certificates with +// an unsupported private key. +func unsupportedCertificateError(cert *Certificate) error { + switch cert.PrivateKey.(type) { + case rsa.PrivateKey, ecdsa.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", + cert.PrivateKey, cert.PrivateKey) + case *ed25519.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") + } + + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", + cert.PrivateKey) + } + + switch pub := signer.Public().(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + case elliptic.P384(): + case elliptic.P521(): + default: + return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) + } + case *rsa.PublicKey: + return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") + case ed25519.PublicKey: + default: + return fmt.Errorf("tls: unsupported certificate key (%T)", pub) + } + + if cert.SupportedSignatureAlgorithms != nil { + return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") + } + + return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/cipher_suites.go b/vendor/github.com/marten-seemann/qtls-go1-18/cipher_suites.go new file mode 100644 index 00000000..e0be5147 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/cipher_suites.go @@ -0,0 +1,691 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "fmt" + "hash" + + "golang.org/x/crypto/chacha20poly1305" +) + +// CipherSuite is a TLS cipher suite. Note that most functions in this package +// accept and expose cipher suite IDs instead of this type. +type CipherSuite struct { + ID uint16 + Name string + + // Supported versions is the list of TLS protocol versions that can + // negotiate this cipher suite. + SupportedVersions []uint16 + + // Insecure is true if the cipher suite has known security issues + // due to its primitives, design, or implementation. + Insecure bool +} + +var ( + supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} + supportedOnlyTLS12 = []uint16{VersionTLS12} + supportedOnlyTLS13 = []uint16{VersionTLS13} +) + +// CipherSuites returns a list of cipher suites currently implemented by this +// package, excluding those with security issues, which are returned by +// InsecureCipherSuites. +// +// The list is sorted by ID. Note that the default cipher suites selected by +// this package might depend on logic that can't be captured by a static list, +// and might not match those returned by this function. +func CipherSuites() []*CipherSuite { + return []*CipherSuite{ + {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + + {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, + {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, + {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, + + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + } +} + +// InsecureCipherSuites returns a list of cipher suites currently implemented by +// this package and which have security issues. +// +// Most applications should not use the cipher suites in this list, and should +// only use those returned by CipherSuites. +func InsecureCipherSuites() []*CipherSuite { + // This list includes RC4, CBC_SHA256, and 3DES cipher suites. See + // cipherSuitesPreferenceOrder for details. + return []*CipherSuite{ + {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + } +} + +// CipherSuiteName returns the standard name for the passed cipher suite ID +// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation +// of the ID value if the cipher suite is not implemented by this package. +func CipherSuiteName(id uint16) string { + for _, c := range CipherSuites() { + if c.ID == id { + return c.Name + } + } + for _, c := range InsecureCipherSuites() { + if c.ID == id { + return c.Name + } + } + return fmt.Sprintf("0x%04X", id) +} + +const ( + // suiteECDHE indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECSign indicates that the cipher suite involves an ECDSA or + // EdDSA signature and therefore may only be selected when the server's + // certificate is ECDSA or EdDSA. If this is not set then the cipher suite + // is RSA based. + suiteECSign + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 +) + +// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange +// mechanism, as well as the cipher+MAC pair or the AEAD. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) any + mac func(key []byte) hash.Hash + aead func(key, fixedNonce []byte) aead +} + +var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil}, +} + +// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which +// is also in supportedIDs and passes the ok filter. +func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { + for _, id := range ids { + candidate := cipherSuiteByID(id) + if candidate == nil || !ok(candidate) { + continue + } + + for _, suppID := range supportedIDs { + if id == suppID { + return candidate + } + } + } + return nil +} + +// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash +// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. +type cipherSuiteTLS13 struct { + id uint16 + keyLen int + aead func(key, fixedNonce []byte) aead + hash crypto.Hash +} + +type CipherSuiteTLS13 struct { + ID uint16 + KeyLen int + Hash crypto.Hash + AEAD func(key, fixedNonce []byte) cipher.AEAD +} + +func (c *CipherSuiteTLS13) IVLen() int { + return aeadNonceLength +} + +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. + {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, + {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, + {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, +} + +// cipherSuitesPreferenceOrder is the order in which we'll select (on the +// server) or advertise (on the client) TLS 1.0–1.2 cipher suites. +// +// Cipher suites are filtered but not reordered based on the application and +// peer's preferences, meaning we'll never select a suite lower in this list if +// any higher one is available. This makes it more defensible to keep weaker +// cipher suites enabled, especially on the server side where we get the last +// word, since there are no known downgrade attacks on cipher suites selection. +// +// The list is sorted by applying the following priority rules, stopping at the +// first (most important) applicable one: +// +// - Anything else comes before RC4 +// +// RC4 has practically exploitable biases. See https://www.rc4nomore.com. +// +// - Anything else comes before CBC_SHA256 +// +// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 +// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. +// +// - Anything else comes before 3DES +// +// 3DES has 64-bit blocks, which makes it fundamentally susceptible to +// birthday attacks. See https://sweet32.info. +// +// - ECDHE comes before anything else +// +// Once we got the broken stuff out of the way, the most important +// property a cipher suite can have is forward secrecy. We don't +// implement FFDHE, so that means ECDHE. +// +// - AEADs come before CBC ciphers +// +// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites +// are fundamentally fragile, and suffered from an endless sequence of +// padding oracle attacks. See https://eprint.iacr.org/2015/1129, +// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and +// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. +// +// - AES comes before ChaCha20 +// +// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster +// than ChaCha20Poly1305. +// +// When AES hardware is not available, AES-128-GCM is one or more of: much +// slower, way more complex, and less safe (because not constant time) +// than ChaCha20Poly1305. +// +// We use this list if we think both peers have AES hardware, and +// cipherSuitesPreferenceOrderNoAES otherwise. +// +// - AES-128 comes before AES-256 +// +// The only potential advantages of AES-256 are better multi-target +// margins, and hypothetical post-quantum properties. Neither apply to +// TLS, and AES-256 is slower due to its four extra rounds (which don't +// contribute to the advantages above). +// +// - ECDSA comes before RSA +// +// The relative order of ECDSA and RSA cipher suites doesn't matter, +// as they depend on the certificate. Pick one to get a stable order. +// +var cipherSuitesPreferenceOrder = []uint16{ + // AEADs w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // CBC w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + + // AEADs w/o ECDHE + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + + // CBC w/o ECDHE + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + + // 3DES + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var cipherSuitesPreferenceOrderNoAES = []uint16{ + // ChaCha20Poly1305 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // AES-GCM w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + + // The rest of cipherSuitesPreferenceOrder. + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +// disabledCipherSuites are not used unless explicitly listed in +// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder. +var disabledCipherSuites = []uint16{ + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var ( + defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) + defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen] +) + +// defaultCipherSuitesTLS13 is also the preference order, since there are no +// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as +// cipherSuitesPreferenceOrder applies. +var defaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, +} + +var defaultCipherSuitesTLS13NoAES = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, +} + +var aesgcmCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, + // TLS 1.3 + TLS_AES_128_GCM_SHA256: true, + TLS_AES_256_GCM_SHA384: true, +} + +var nonAESGCMAEADCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, + // TLS 1.3 + TLS_CHACHA20_POLY1305_SHA256: true, +} + +// aesgcmPreferred returns whether the first known cipher in the preference list +// is an AES-GCM cipher, implying the peer has hardware support for it. +func aesgcmPreferred(ciphers []uint16) bool { + for _, cID := range ciphers { + if c := cipherSuiteByID(cID); c != nil { + return aesgcmCiphers[cID] + } + if c := cipherSuiteTLS13ByID(cID); c != nil { + return aesgcmCiphers[cID] + } + } + return false +} + +func cipherRC4(key, iv []byte, isRead bool) any { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) any { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) any { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a SHA-1 based constant time MAC. +func macSHA1(key []byte) hash.Hash { + return hmac.New(newConstantTimeHash(sha1.New), key) +} + +// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and +// is currently only used in disabled-by-default cipher suites. +func macSHA256(key []byte) hash.Hash { + return hmac.New(sha256.New, key) +} + +type aead interface { + cipher.AEAD + + // explicitNonceLen returns the number of bytes of explicit nonce + // included in each record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +const ( + aeadNonceLength = 12 + noncePrefixLength = 4 +) + +// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type prefixNonceAEAD struct { + // nonce contains the fixed part of the nonce in the first four bytes. + nonce [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } +func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } + +func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.nonce[4:], nonce) + return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) +} + +func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + copy(f.nonce[4:], nonce) + return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) +} + +// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} + +func aeadAESGCM(key, noncePrefix []byte) aead { + if len(noncePrefix) != noncePrefixLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &prefixNonceAEAD{aead: aead} + copy(ret.nonce[:], noncePrefix) + return ret +} + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return aeadAESGCMTLS13(key, fixedNonce) +} + +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. +func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { + h.Reset() + h.Write(seq) + h.Write(header) + h.Write(data) + res := h.Sum(out) + if extra != nil { + h.Write(extra) + } + return res +} + +func rsaKA(version uint16) keyAgreement { + return rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: false, + version: version, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: true, + version: version, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + return cipherSuiteByID(id) + } + } + return nil +} + +func cipherSuiteByID(id uint16) *cipherSuite { + for _, cipherSuite := range cipherSuites { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { + for _, id := range have { + if id == want { + return cipherSuiteTLS13ByID(id) + } + } + return nil +} + +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { + for _, cipherSuite := range cipherSuitesTLS13 { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +// A list of cipher suite IDs that are, or have been, implemented by this +// package. +// +// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 + + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See RFC 7507. + TLS_FALLBACK_SCSV uint16 = 0x5600 + + // Legacy names for the corresponding cipher suites with the correct _SHA256 + // suffix, retained for backward compatibility. + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/common.go b/vendor/github.com/marten-seemann/qtls-go1-18/common.go new file mode 100644 index 00000000..4c9aeeb4 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/common.go @@ -0,0 +1,1507 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "container/list" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "strings" + "sync" + "time" +) + +const ( + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 + + // Deprecated: SSLv3 is cryptographically broken, and is no longer + // supported by this package. See golang.org/issue/32716. + VersionSSL30 = 0x0300 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxUselessRecords = 16 // maximum number of consecutive non-advancing records +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeMessageHash uint8 = 254 // synthetic message +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +type Extension struct { + Type uint16 + Data []byte +} + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionRenegotiationInfo uint16 = 0xff01 +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +type EncryptionLevel uint8 + +const ( + EncryptionHandshake EncryptionLevel = iota + Encryption0RTT + EncryptionApplication +) + +// CurveID is a tls.CurveID +type CurveID = tls.CurveID + +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 +) + +// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. +const ( + pskModePlain uint8 = 0 + pskModeDHE uint8 = 1 +) + +// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved +// session. See RFC 8446, Section 4.2.11. +type pskIdentity struct { + label []byte + obfuscatedTicketAge uint32 +} + +// TLS Elliptic Curve Point Formats +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 + certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. +) + +// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with +// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. +const ( + signaturePKCS1v15 uint8 = iota + 225 + signatureRSAPSS + signatureECDSA + signatureEd25519 +) + +// directSigning is a standard Hash value that signals that no pre-hashing +// should be performed, and that the input should be signed directly. It is the +// hash function associated with the Ed25519 signature scheme. +var directSigning crypto.Hash = 0 + +// supportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ +// CertificateRequest. The two fields are merged to match with TLS 1.3. +// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. +var supportedSignatureAlgorithms = []SignatureScheme{ + PSSWithSHA256, + ECDSAWithP256AndSHA256, + Ed25519, + PSSWithSHA384, + PSSWithSHA512, + PKCS1WithSHA256, + PKCS1WithSHA384, + PKCS1WithSHA512, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// helloRetryRequestRandom is set as the Random value of a ServerHello +// to signal that the message is actually a HelloRetryRequest. +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +const ( + // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server + // random as a downgrade protection if the server would be capable of + // negotiating a higher version. See RFC 8446, Section 4.1.3. + downgradeCanaryTLS12 = "DOWNGRD\x01" + downgradeCanaryTLS11 = "DOWNGRD\x00" +) + +// testingOnlyForceDowngradeCanary is set in tests to force the server side to +// include downgrade canaries even if it's using its highers supported version. +var testingOnlyForceDowngradeCanary bool + +type ConnectionState = tls.ConnectionState + +// ConnectionState records basic TLS details about the connection. +type connectionState struct { + // Version is the TLS version used by the connection (e.g. VersionTLS12). + Version uint16 + + // HandshakeComplete is true if the handshake has concluded. + HandshakeComplete bool + + // DidResume is true if this connection was successfully resumed from a + // previous session with a session ticket or similar mechanism. + DidResume bool + + // CipherSuite is the cipher suite negotiated for the connection (e.g. + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). + CipherSuite uint16 + + // NegotiatedProtocol is the application protocol negotiated with ALPN. + NegotiatedProtocol string + + // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. + // + // Deprecated: this value is always true. + NegotiatedProtocolIsMutual bool + + // ServerName is the value of the Server Name Indication extension sent by + // the client. It's available both on the server and on the client side. + ServerName string + + // PeerCertificates are the parsed certificates sent by the peer, in the + // order in which they were sent. The first element is the leaf certificate + // that the connection is verified against. + // + // On the client side, it can't be empty. On the server side, it can be + // empty if Config.ClientAuth is not RequireAnyClientCert or + // RequireAndVerifyClientCert. + PeerCertificates []*x509.Certificate + + // VerifiedChains is a list of one or more chains where the first element is + // PeerCertificates[0] and the last element is from Config.RootCAs (on the + // client side) or Config.ClientCAs (on the server side). + // + // On the client side, it's set if Config.InsecureSkipVerify is false. On + // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven + // (and the peer provided a certificate) or RequireAndVerifyClientCert. + VerifiedChains [][]*x509.Certificate + + // SignedCertificateTimestamps is a list of SCTs provided by the peer + // through the TLS handshake for the leaf certificate, if any. + SignedCertificateTimestamps [][]byte + + // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) + // response provided by the peer for the leaf certificate, if any. + OCSPResponse []byte + + // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, + // Section 3). This value will be nil for TLS 1.3 connections and for all + // resumed connections. + // + // Deprecated: there are conditions in which this value might not be unique + // to a connection. See the Security Considerations sections of RFC 5705 and + // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. + TLSUnique []byte + + // ekm is a closure exposed via ExportKeyingMaterial. + ekm func(label string, context []byte, length int) ([]byte, error) +} + +type ConnectionStateWith0RTT struct { + ConnectionState + + Used0RTT bool // true if 0-RTT was both offered and accepted +} + +// ClientAuthType is tls.ClientAuthType +type ClientAuthType = tls.ClientAuthType + +const ( + NoClientCert = tls.NoClientCert + RequestClientCert = tls.RequestClientCert + RequireAnyClientCert = tls.RequireAnyClientCert + VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven + RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert +) + +// requiresClientCert reports whether the ClientAuthType requires a client +// certificate to be provided. +func requiresClientCert(c ClientAuthType) bool { + switch c { + case RequireAnyClientCert, RequireAndVerifyClientCert: + return true + default: + return false + } +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState = tls.ClientSessionState + +type clientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + receivedAt time.Time // When the session ticket was received from the server + ocspResponse []byte // Stapled OCSP response presented by the server + scts [][]byte // SCTs presented by the server + + // TLS 1.3 fields. + nonce []byte // Ticket nonce sent by the server, to derive PSK + useBy time.Time // Expiration of the ticket lifetime as set by the server + ageAdd uint32 // Random obfuscation factor for sending the ticket age +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not +// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which +// are supported via this interface. +//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-17 ClientSessionCache" +type ClientSessionCache = tls.ClientSessionCache + +// SignatureScheme is a tls.SignatureScheme +type SignatureScheme = tls.SignatureScheme + +const ( + // RSASSA-PKCS1-v1_5 algorithms. + PKCS1WithSHA256 SignatureScheme = 0x0401 + PKCS1WithSHA384 SignatureScheme = 0x0501 + PKCS1WithSHA512 SignatureScheme = 0x0601 + + // RSASSA-PSS algorithms with public key OID rsaEncryption. + PSSWithSHA256 SignatureScheme = 0x0804 + PSSWithSHA384 SignatureScheme = 0x0805 + PSSWithSHA512 SignatureScheme = 0x0806 + + // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. + ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 + ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 + ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 + + // EdDSA algorithms. + Ed25519 SignatureScheme = 0x0807 + + // Legacy signature and hash algorithms for TLS 1.2. + PKCS1WithSHA1 SignatureScheme = 0x0201 + ECDSAWithSHA1 SignatureScheme = 0x0203 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide application logic in the GetCertificate and GetConfigForClient callbacks. +type ClientHelloInfo = tls.ClientHelloInfo + +type clientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see RFC 4366, Section 3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see RFC 4492, Section 5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see RFC 4492, Section 5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see RFC 7301, Section 3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn net.Conn + + // config is embedded by the GetCertificate or GetConfigForClient caller, + // for use with SupportsCertificate. + config *Config + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *clientHelloInfo) Context() context.Context { + return c.ctx +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo = tls.CertificateRequestInfo + +type certificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme + + // Version is the TLS version that was negotiated for this connection. + Version uint16 + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *certificateRequestInfo) Context() context.Context { + return c.ctx +} + +// RenegotiationSupport enumerates the different levels of support for TLS +// renegotiation. TLS renegotiation is the act of performing subsequent +// handshakes on a connection after the first. This significantly complicates +// the state machine and has been the source of numerous, subtle security +// issues. Initiating a renegotiation is not supported, but support for +// accepting renegotiation requests may be enabled. +// +// Even when enabled, the server may not change its identity between handshakes +// (i.e. the leaf certificate must be the same). Additionally, concurrent +// handshake and application data flow is not permitted so renegotiation can +// only be used with protocols that synchronise with the renegotiation, such as +// HTTPS. +// +// Renegotiation is not defined in TLS 1.3. +type RenegotiationSupport = tls.RenegotiationSupport + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever = tls.RenegotiateNever + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config = tls.Config + +type config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to the + // other side of the connection. The first certificate compatible with the + // peer's requirements is selected automatically. + // + // Server configurations must set one of Certificates, GetCertificate or + // GetConfigForClient. Clients doing client-authentication may set either + // Certificates or GetClientCertificate. + // + // Note: if there are multiple Certificates, and they don't have the + // optional field Leaf set, certificate selection will incur a significant + // per-handshake performance cost. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // + // Deprecated: NameToCertificate only allows associating a single + // certificate with a given name. Leave this field nil to let the library + // select the first compatible chain from Certificates. + NameToCertificate map[string]*Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // best element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // If SessionTicketKey was explicitly set on the returned Config, or if + // SetSessionTicketKeys was called on the returned Config, those keys will + // be used. Otherwise, the original Config keys will be used (and possibly + // rotated if they are automatically managed). + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported application level protocols, in + // order of preference. If both peers support ALPN, the selected + // protocol will be one from this list, and the connection will fail + // if there is no mutually supported protocol. If NextProtos is empty + // or the peer doesn't support ALPN, the connection will succeed and + // ConnectionState.NegotiatedProtocol will be empty. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the server's + // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls + // accepts any certificate presented by the server and any host name in that + // certificate. In this mode, TLS is susceptible to machine-in-the-middle + // attacks unless custom verification is used. This should be used only for + // testing or in combination with VerifyConnection or VerifyPeerCertificate. + InsecureSkipVerify bool + + // CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of + // the list is ignored. Note that TLS 1.3 ciphersuites are not configurable. + // + // If CipherSuites is nil, a safe default list is used. The default cipher + // suites might change over time. + CipherSuites []uint16 + + // PreferServerCipherSuites is a legacy field and has no effect. + // + // It used to control whether the server would follow the client's or the + // server's preference. Servers now select the best mutually supported + // cipher suite based on logic that takes into account inferred client + // hardware, server hardware, and security. + // + // Deprecated: PreferServerCipherSuites is ignored. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket and + // PSK (resumption) support. Note that on clients, session ticket support is + // also disabled if ClientSessionCache is nil. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session resumption. + // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled + // with random data before the first server handshake. + // + // Deprecated: if this field is left at zero, session ticket keys will be + // automatically rotated every day and dropped after seven days. For + // customizing the rotation schedule or synchronizing servers that are + // terminating connections for the same host, use SetSessionTicketKeys. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries for TLS + // session resumption. It is only used by clients. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum TLS version that is acceptable. + // + // By default, TLS 1.2 is currently used as the minimum when acting as a + // client, and TLS 1.0 when acting as a server. TLS 1.0 is the minimum + // supported by this package, both as a client and as a server. + // + // The client-side default can temporarily be reverted to TLS 1.0 by + // including the value "x509sha1=1" in the GODEBUG environment variable. + // Note that this option will be removed in Go 1.19 (but it will still be + // possible to set this field to VersionTLS10 explicitly). + MinVersion uint16 + + // MaxVersion contains the maximum TLS version that is acceptable. + // + // By default, the maximum version supported by this package is used, + // which is currently TLS 1.3. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. The client will use the first preference as the type for + // its key share in TLS 1.3. This may change in the future. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // mutex protects sessionTicketKeys and autoSessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // the keys were set with SessionTicketKey or SetSessionTicketKeys. The + // first key is used for new tickets and any subsequent keys can be used to + // decrypt old tickets. The slice contents are not protected by the mutex + // and are immutable. + sessionTicketKeys []ticketKey + // autoSessionTicketKeys is like sessionTicketKeys but is owned by the + // auto-rotation logic. See Config.ticketKeys. + autoSessionTicketKeys []ticketKey +} + +// A RecordLayer handles encrypting and decrypting of TLS messages. +type RecordLayer interface { + SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + ReadHandshakeMessage() ([]byte, error) + WriteRecord([]byte) (int, error) + SendAlert(uint8) +} + +type ExtraConfig struct { + // GetExtensions, if not nil, is called before a message that allows + // sending of extensions is sent. + // Currently only implemented for the ClientHello message (for the client) + // and for the EncryptedExtensions message (for the server). + // Only valid for TLS 1.3. + GetExtensions func(handshakeMessageType uint8) []Extension + + // ReceivedExtensions, if not nil, is called when a message that allows the + // inclusion of extensions is received. + // It is called with an empty slice of extensions, if the message didn't + // contain any extensions. + // Currently only implemented for the ClientHello message (sent by the + // client) and for the EncryptedExtensions message (sent by the server). + // Only valid for TLS 1.3. + ReceivedExtensions func(handshakeMessageType uint8, exts []Extension) + + // AlternativeRecordLayer is used by QUIC + AlternativeRecordLayer RecordLayer + + // Enforce the selection of a supported application protocol. + // Only works for TLS 1.3. + // If enabled, client and server have to agree on an application protocol. + // Otherwise, connection establishment fails. + EnforceNextProtoSelection bool + + // If MaxEarlyData is greater than 0, the client will be allowed to send early + // data when resuming a session. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning on the client. + MaxEarlyData uint32 + + // The Accept0RTT callback is called when the client offers 0-RTT. + // The server then has to decide if it wants to accept or reject 0-RTT. + // It is only used for servers. + Accept0RTT func(appData []byte) bool + + // 0RTTRejected is called when the server rejectes 0-RTT. + // It is only used for clients. + Rejected0RTT func() + + // If set, the client will export the 0-RTT key when resuming a session that + // allows sending of early data. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning to the server. + Enable0RTT bool + + // Is called when the client saves a session ticket to the session ticket. + // This gives the application the opportunity to save some data along with the ticket, + // which can be restored when the session ticket is used. + GetAppDataForSessionState func() []byte + + // Is called when the client uses a session ticket. + // Restores the application data that was saved earlier on GetAppDataForSessionTicket. + SetAppDataFromSessionState func([]byte) +} + +// Clone clones. +func (c *ExtraConfig) Clone() *ExtraConfig { + return &ExtraConfig{ + GetExtensions: c.GetExtensions, + ReceivedExtensions: c.ReceivedExtensions, + AlternativeRecordLayer: c.AlternativeRecordLayer, + EnforceNextProtoSelection: c.EnforceNextProtoSelection, + MaxEarlyData: c.MaxEarlyData, + Enable0RTT: c.Enable0RTT, + Accept0RTT: c.Accept0RTT, + Rejected0RTT: c.Rejected0RTT, + GetAppDataForSessionState: c.GetAppDataForSessionState, + SetAppDataFromSessionState: c.SetAppDataFromSessionState, + } +} + +func (c *ExtraConfig) usesAlternativeRecordLayer() bool { + return c != nil && c.AlternativeRecordLayer != nil +} + +const ( + // ticketKeyNameLen is the number of bytes of identifier that is prepended to + // an encrypted session ticket in order to identify the key used to encrypt it. + ticketKeyNameLen = 16 + + // ticketKeyLifetime is how long a ticket key remains valid and can be used to + // resume a client connection. + ticketKeyLifetime = 7 * 24 * time.Hour // 7 days + + // ticketKeyRotation is how often the server should rotate the session ticket key + // that is used for new tickets. + ticketKeyRotation = 24 * time.Hour +) + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte + // created is the time at which this ticket key was created. See Config.ticketKeys. + created time.Time +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func (c *config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + key.created = c.time() + return key +} + +// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session +// ticket, and the lifetime we set for tickets we send. +const maxSessionTicketLifetime = 7 * 24 * time.Hour + +// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *config) Clone() *config { + if c == nil { + return nil + } + c.mutex.RLock() + defer c.mutex.RUnlock() + return &config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: c.sessionTicketKeys, + autoSessionTicketKeys: c.autoSessionTicketKeys, + } +} + +// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was +// randomized for backwards compatibility but is not in use. +var deprecatedSessionTicketKey = []byte("DEPRECATED") + +// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is +// randomized if empty, and that sessionTicketKeys is populated from it otherwise. +func (c *config) initLegacySessionTicketKeyRLocked() { + // Don't write if SessionTicketKey is already defined as our deprecated string, + // or if it is defined by the user but sessionTicketKeys is already set. + if c.SessionTicketKey != [32]byte{} && + (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { + return + } + + // We need to write some data, so get an exclusive lock and re-check any conditions. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + if c.SessionTicketKey == [32]byte{} { + if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) + } + // Write the deprecated prefix at the beginning so we know we created + // it. This key with the DEPRECATED prefix isn't used as an actual + // session ticket key, and is only randomized in case the application + // reuses it for some reason. + copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) + } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { + c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} + } + +} + +// ticketKeys returns the ticketKeys for this connection. +// If configForClient has explicitly set keys, those will +// be returned. Otherwise, the keys on c will be used and +// may be rotated if auto-managed. +// During rotation, any expired session ticket keys are deleted from +// c.sessionTicketKeys. If the session ticket key that is currently +// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) +// is not fresh, then a new session ticket key will be +// created and prepended to c.sessionTicketKeys. +func (c *config) ticketKeys(configForClient *config) []ticketKey { + // If the ConfigForClient callback returned a Config with explicitly set + // keys, use those, otherwise just use the original Config. + if configForClient != nil { + configForClient.mutex.RLock() + if configForClient.SessionTicketsDisabled { + return nil + } + configForClient.initLegacySessionTicketKeyRLocked() + if len(configForClient.sessionTicketKeys) != 0 { + ret := configForClient.sessionTicketKeys + configForClient.mutex.RUnlock() + return ret + } + configForClient.mutex.RUnlock() + } + + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.SessionTicketsDisabled { + return nil + } + c.initLegacySessionTicketKeyRLocked() + if len(c.sessionTicketKeys) != 0 { + return c.sessionTicketKeys + } + // Fast path for the common case where the key is fresh enough. + if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { + return c.autoSessionTicketKeys + } + + // autoSessionTicketKeys are managed by auto-rotation. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + // Re-check the condition in case it changed since obtaining the new lock. + if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { + var newKey [32]byte + if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { + panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) + } + valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) + valid = append(valid, c.ticketKeyFromBytes(newKey)) + for _, k := range c.autoSessionTicketKeys { + // While rotating the current key, also remove any expired ones. + if c.time().Sub(k.created) < ticketKeyLifetime { + valid = append(valid, k) + } + } + c.autoSessionTicketKeys = valid + } + return c.autoSessionTicketKeys +} + +// SetSessionTicketKeys updates the session ticket keys for a server. +// +// The first key will be used when creating new tickets, while all keys can be +// used for decrypting tickets. It is safe to call this function while the +// server is running in order to rotate the session ticket keys. The function +// will panic if keys is empty. +// +// Calling this function will turn off automatic session ticket key rotation. +// +// If multiple servers are terminating connections for the same host they should +// all have the same session ticket keys. If the session ticket keys leaks, +// previously recorded and future TLS connections using those keys might be +// compromised. +func (c *config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = c.ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *config) cipherSuites() []uint16 { + if c.CipherSuites != nil { + return c.CipherSuites + } + return defaultCipherSuites +} + +var supportedVersions = []uint16{ + VersionTLS13, + VersionTLS12, + VersionTLS11, + VersionTLS10, +} + +// debugEnableTLS10 enables TLS 1.0. See issue 45428. +// We don't care about TLS1.0 in qtls. Always disable it. +var debugEnableTLS10 = false + +// roleClient and roleServer are meant to call supportedVersions and parents +// with more readability at the callsite. +const roleClient = true +const roleServer = false + +func (c *config) supportedVersions(isClient bool) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if (c == nil || c.MinVersion == 0) && !debugEnableTLS10 && + isClient && v < VersionTLS12 { + continue + } + if c != nil && c.MinVersion != 0 && v < c.MinVersion { + continue + } + if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +func (c *config) maxSupportedVersion(isClient bool) uint16 { + supportedVersions := c.supportedVersions(isClient) + if len(supportedVersions) == 0 { + return 0 + } + return supportedVersions[0] +} + +// supportedVersionsFromMax returns a list of supported versions derived from a +// legacy maximum version value. Note that only versions supported by this +// library are returned. Any newer peer will use supportedVersions anyway. +func supportedVersionsFromMax(maxVersion uint16) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if v > maxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +func (c *config) supportsCurve(curve CurveID) bool { + for _, cc := range c.curvePreferences() { + if cc == curve { + return true + } + } + return false +} + +// mutualVersion returns the protocol version to use given the advertised +// versions of the peer. Priority is given to the peer preference order. +func (c *config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions(isClient) + for _, peerVersion := range peerVersions { + for _, v := range supportedVersions { + if v == peerVersion { + return v, true + } + } + } + return 0, false +} + +var errNoCertificates = errors.New("tls: no certificates configured") + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errNoCertificates + } + + if len(c.Certificates) == 1 { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + if c.NameToCertificate != nil { + name := strings.ToLower(clientHello.ServerName) + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + if len(name) > 0 { + labels := strings.Split(name, ".") + labels[0] = "*" + wildcardName := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[wildcardName]; ok { + return cert, nil + } + } + } + + for _, cert := range c.Certificates { + if err := clientHello.SupportsCertificate(&cert); err == nil { + return &cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the client that sent the ClientHello. Otherwise, it returns an error +// describing the reason for the incompatibility. +// +// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate +// callback, this method will take into account the associated Config. Note that +// if GetConfigForClient returns a different Config, the change can't be +// accounted for by this method. +// +// This function will call x509.ParseCertificate unless c.Leaf is set, which can +// incur a significant performance cost. +func (chi *clientHelloInfo) SupportsCertificate(c *Certificate) error { + // Note we don't currently support certificate_authorities nor + // signature_algorithms_cert, and don't check the algorithms of the + // signatures on the chain (which anyway are a SHOULD, see RFC 8446, + // Section 4.4.2.2). + + config := chi.config + if config == nil { + config = &Config{} + } + conf := fromConfig(config) + vers, ok := conf.mutualVersion(roleServer, chi.SupportedVersions) + if !ok { + return errors.New("no mutually supported protocol versions") + } + + // If the client specified the name they are trying to connect to, the + // certificate needs to be valid for it. + if chi.ServerName != "" { + x509Cert, err := leafCertificate(c) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { + return fmt.Errorf("certificate is not valid for requested server name: %w", err) + } + } + + // supportsRSAFallback returns nil if the certificate and connection support + // the static RSA key exchange, and unsupported otherwise. The logic for + // supporting static RSA is completely disjoint from the logic for + // supporting signed key exchanges, so we just check it as a fallback. + supportsRSAFallback := func(unsupported error) error { + // TLS 1.3 dropped support for the static RSA key exchange. + if vers == VersionTLS13 { + return unsupported + } + // The static RSA key exchange works by decrypting a challenge with the + // RSA private key, not by signing, so check the PrivateKey implements + // crypto.Decrypter, like *rsa.PrivateKey does. + if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { + if _, ok := priv.Public().(*rsa.PublicKey); !ok { + return unsupported + } + } else { + return unsupported + } + // Finally, there needs to be a mutual cipher suite that uses the static + // RSA key exchange instead of ECDHE. + rsaCipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + return false + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if rsaCipherSuite == nil { + return unsupported + } + return nil + } + + // If the client sent the signature_algorithms extension, ensure it supports + // schemes we can use with this certificate and TLS version. + if len(chi.SignatureSchemes) > 0 { + if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { + return supportsRSAFallback(err) + } + } + + // In TLS 1.3 we are done because supported_groups is only relevant to the + // ECDHE computation, point format negotiation is removed, cipher suites are + // only relevant to the AEAD choice, and static RSA does not exist. + if vers == VersionTLS13 { + return nil + } + + // The only signed key exchange we support is ECDHE. + if !supportsECDHE(conf, chi.SupportedCurves, chi.SupportedPoints) { + return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) + } + + var ecdsaCipherSuite bool + if priv, ok := c.PrivateKey.(crypto.Signer); ok { + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + var curve CurveID + switch pub.Curve { + case elliptic.P256(): + curve = CurveP256 + case elliptic.P384(): + curve = CurveP384 + case elliptic.P521(): + curve = CurveP521 + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + var curveOk bool + for _, c := range chi.SupportedCurves { + if c == curve && conf.supportsCurve(c) { + curveOk = true + break + } + } + if !curveOk { + return errors.New("client doesn't support certificate curve") + } + ecdsaCipherSuite = true + case ed25519.PublicKey: + if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { + return errors.New("connection doesn't support Ed25519") + } + ecdsaCipherSuite = true + case *rsa.PublicKey: + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + } else { + return supportsRSAFallback(unsupportedCertificateError(c)) + } + + // Make sure that there is a mutually supported cipher suite that works with + // this certificate. Cipher suite selection will then apply the logic in + // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. + cipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE == 0 { + return false + } + if c.flags&suiteECSign != 0 { + if !ecdsaCipherSuite { + return false + } + } else { + if ecdsaCipherSuite { + return false + } + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if cipherSuite == nil { + return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) + } + + return nil +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +// +// Deprecated: NameToCertificate only allows associating a single certificate +// with a given name. Leave that field nil to let the library select the first +// compatible chain from Certificates. +func (c *config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := leafCertificate(cert) + if err != nil { + continue + } + // If SANs are *not* present, some clients will consider the certificate + // valid for the name in the Common Name. + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +const ( + keyLogLabelTLS12 = "CLIENT_RANDOM" + keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET" + keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" + keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" +) + +func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate = tls.Certificate + +// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing +// the corresponding c.Certificate[0]. +func leafCertificate(c *Certificate) (*x509.Certificate, error) { + if c.Leaf != nil { + return c.Leaf, nil + } + return x509.ParseCertificate(c.Certificate[0]) +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry +// corresponding to sessionKey is removed from the cache instead. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + if cs == nil { + c.q.Remove(elem) + delete(c.m, sessionKey) + } else { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + } + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +func unexpectedMessageError(wanted, got any) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { + for _, s := range supportedSignatureAlgorithms { + if s == sigAlg { + return true + } + } + return false +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/conn.go b/vendor/github.com/marten-seemann/qtls-go1-18/conn.go new file mode 100644 index 00000000..90a27b5d --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/conn.go @@ -0,0 +1,1608 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package qtls + +import ( + "bytes" + "context" + "crypto/cipher" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + isClient bool + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake + + // handshakeStatus is 1 if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // handshakeStatus == 1 implies handshakeErr == nil. + // This field is only to be accessed with sync/atomic. + handshakeStatus uint32 + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *config // configuration passed to constructor + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + extraConfig *ExtraConfig + + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + // ekm is a closure for exporting keying material. + ekm func(label string, context []byte, length int) ([]byte, error) + // For the client: + // resumptionSecret is the resumption_master_secret for handling + // NewSessionTicket messages. nil if config.SessionTicketsDisabled. + // For the server: + // resumptionSecret is the resumption_master_secret for generating + // NewSessionTicket messages. Only used when the alternative record + // layer is set. nil if config.SessionTicketsDisabled. + resumptionSecret []byte + + // ticketKeys is the set of active session ticket keys for this + // connection. The first one is used to encrypt new tickets and + // all are tried to decrypt tickets. + ticketKeys []ticketKey + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string + + // input/output + in, out halfConn + rawInput bytes.Buffer // raw input, starting with a record header + input bytes.Reader // application data waiting to be read, from rawInput.Next + hand bytes.Buffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf []byte // a buffer of records waiting to be sent + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 + + // retryCount counts the number of consecutive non-advancing records + // received by Conn.readRecord. That is, records that neither advance the + // handshake, nor deliver application data. Protected by in.Mutex. + retryCount int + + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + activeCall int32 + + used0RTT bool + + tmp [16]byte +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// TLS session. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher any // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape + + nextCipher any // next encryption state + nextMac hash.Hash // next MAC algorithm + + trafficSecret []byte // current TLS 1.3 traffic secret + + setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) +} + +type permanentError struct { + err net.Error +} + +func (e *permanentError) Error() string { return e.err.Error() } +func (e *permanentError) Unwrap() error { return e.err } +func (e *permanentError) Timeout() bool { return e.err.Timeout() } +func (e *permanentError) Temporary() bool { return false } + +func (hc *halfConn) setErrorLocked(err error) error { + if e, ok := err.(net.Error); ok { + hc.err = &permanentError{err: e} + } else { + hc.err = err + } + return hc.err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() error { + if hc.nextCipher == nil || hc.version == VersionTLS13 { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + for i := range hc.seq { + hc.seq[i] = 0 + } + return nil +} + +func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) { + if hc.setKeyCallback != nil { + s := &CipherSuiteTLS13{ + ID: suite.id, + KeyLen: suite.keyLen, + Hash: suite.hash, + AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) }, + } + hc.setKeyCallback(encLevel, s, trafficSecret) + } +} + +func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { + hc.trafficSecret = secret + key, iv := suite.trafficKey(secret) + hc.cipher = suite.aead(key, iv) + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// explicitNonceLen returns the number of bytes of explicit nonce or IV included +// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 +// and in certain AEAD modes in TLS 1.2. +func (hc *halfConn) explicitNonceLen() int { + if hc.cipher == nil { + return 0 + } + + switch c := hc.cipher.(type) { + case cipher.Stream: + return 0 + case aead: + return c.explicitNonceLen() + case cbcMode: + // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. + if hc.version >= VersionTLS11 { + return c.BlockSize() + } + return 0 + default: + panic("unknown cipher type") + } +} + +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +func extractPadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + // Zero the padding length on error. This ensures any unchecked bytes + // are included in the MAC. Otherwise, an attacker that could + // distinguish MAC failures from padding failures could mount an attack + // similar to POODLE in SSL 3.0: given a good ciphertext that uses a + // full block's worth of padding, replace the final block with another + // block. If the MAC check passed but the padding check failed, the + // last byte of that block decrypted to the block size. + // + // See also macAndPaddingGood logic below. + paddingLen &= good + + toRemove = int(paddingLen) + 1 + return +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt authenticates and decrypts the record if protection is active at +// this stage. The returned plaintext might overlap with the input. +func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { + var plaintext []byte + typ := recordType(record[0]) + payload := record[recordHeaderLen:] + + // In TLS 1.3, change_cipher_spec messages are to be ignored without being + // decrypted. See RFC 8446, Appendix D.4. + if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { + return payload, typ, nil + } + + paddingGood := byte(255) + paddingLen := 0 + + explicitNonceLen := hc.explicitNonceLen() + + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + if len(payload) < explicitNonceLen { + return nil, 0, alertBadRecordMAC + } + nonce := payload[:explicitNonceLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } + payload = payload[explicitNonceLen:] + + var additionalData []byte + if hc.version == VersionTLS13 { + additionalData = record[:recordHeaderLen] + } else { + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) + n := len(payload) - c.Overhead() + additionalData = append(additionalData, byte(n>>8), byte(n)) + } + + var err error + plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) + if err != nil { + return nil, 0, alertBadRecordMAC + } + case cbcMode: + blockSize := c.BlockSize() + minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) + if len(payload)%blockSize != 0 || len(payload) < minPayload { + return nil, 0, alertBadRecordMAC + } + + if explicitNonceLen > 0 { + c.SetIV(payload[:explicitNonceLen]) + payload = payload[explicitNonceLen:] + } + c.CryptBlocks(payload, payload) + + // In a limited attempt to protect against CBC padding oracles like + // Lucky13, the data past paddingLen (which is secret) is passed to + // the MAC function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC roughly constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write, modulo cache effects. + paddingLen, paddingGood = extractPadding(payload) + default: + panic("unknown cipher type") + } + + if hc.version == VersionTLS13 { + if typ != recordTypeApplicationData { + return nil, 0, alertUnexpectedMessage + } + if len(plaintext) > maxPlaintext+1 { + return nil, 0, alertRecordOverflow + } + // Remove padding and find the ContentType scanning from the end. + for i := len(plaintext) - 1; i >= 0; i-- { + if plaintext[i] != 0 { + typ = recordType(plaintext[i]) + plaintext = plaintext[:i] + break + } + if i == 0 { + return nil, 0, alertUnexpectedMessage + } + } + } + } else { + plaintext = payload + } + + if hc.mac != nil { + macSize := hc.mac.Size() + if len(payload) < macSize { + return nil, 0, alertBadRecordMAC + } + + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } + record[3] = byte(n >> 8) + record[4] = byte(n) + remoteMAC := payload[n : n+macSize] + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + + // This is equivalent to checking the MACs and paddingGood + // separately, but in constant-time to prevent distinguishing + // padding failures from MAC failures. Depending on what value + // of paddingLen was returned on bad padding, distinguishing + // bad MAC from bad padding can lead to an attack. + // + // See also the logic at the end of extractPadding. + macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) + if macAndPaddingGood != 1 { + return nil, 0, alertBadRecordMAC + } + + plaintext = payload[:n] + } + + hc.incSeq() + return plaintext, typ, nil +} + +func (c *Conn) setAlternativeRecordLayer() { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey + c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey + } +} + +// sliceForAppend extends the input slice by n bytes. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and +// appends it to record, which must already contain the record header. +func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { + if hc.cipher == nil { + return append(record, payload...), nil + } + + var explicitNonce []byte + if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { + record, explicitNonce = sliceForAppend(record, explicitNonceLen) + if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { + // The AES-GCM construction in TLS has an explicit nonce so that the + // nonce can be random. However, the nonce is only 8 bytes which is + // too small for a secure, random nonce. Therefore we use the + // sequence number as the nonce. The 3DES-CBC construction also has + // an 8 bytes nonce but its nonces must be unpredictable (see RFC + // 5246, Appendix F.3), forcing us to use randomness. That's not + // 3DES' biggest problem anyway because the birthday bound on block + // collision is reached first due to its similarly small block size + // (see the Sweet32 attack). + copy(explicitNonce, hc.seq[:]) + } else { + if _, err := io.ReadFull(rand, explicitNonce); err != nil { + return nil, err + } + } + } + + var dst []byte + switch c := hc.cipher.(type) { + case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + record, dst = sliceForAppend(record, len(payload)+len(mac)) + c.XORKeyStream(dst[:len(payload)], payload) + c.XORKeyStream(dst[len(payload):], mac) + case aead: + nonce := explicitNonce + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + if hc.version == VersionTLS13 { + record = append(record, payload...) + + // Encrypt the actual ContentType and replace the plaintext one. + record = append(record, record[0]) + record[0] = byte(recordTypeApplicationData) + + n := len(payload) + 1 + c.Overhead() + record[3] = byte(n >> 8) + record[4] = byte(n) + + record = c.Seal(record[:recordHeaderLen], + nonce, record[recordHeaderLen:], record[:recordHeaderLen]) + } else { + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) + } + case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + blockSize := c.BlockSize() + plaintextLen := len(payload) + len(mac) + paddingLen := blockSize - plaintextLen%blockSize + record, dst = sliceForAppend(record, plaintextLen+paddingLen) + copy(dst, payload) + copy(dst[len(payload):], mac) + for i := plaintextLen; i < len(dst); i++ { + dst[i] = byte(paddingLen - 1) + } + if len(explicitNonce) > 0 { + c.SetIV(explicitNonce) + } + c.CryptBlocks(dst, dst) + default: + panic("unknown cipher type") + } + + // Update length to include nonce, MAC and any block padding needed. + n := len(record) - recordHeaderLen + record[3] = byte(n >> 8) + record[4] = byte(n) + hc.incSeq() + + return record, nil +} + +// RecordHeaderError is returned when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte + // Conn provides the underlying net.Conn in the case that a client + // sent an initial handshake that didn't look like TLS. + // It is nil if there's already been a handshake or a TLS alert has + // been written to the connection. + Conn net.Conn +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { + err.Msg = msg + err.Conn = conn + copy(err.RecordHeader[:], c.rawInput.Bytes()) + return err +} + +func (c *Conn) readRecord() error { + return c.readRecordOrCCS(false) +} + +func (c *Conn) readChangeCipherSpec() error { + return c.readRecordOrCCS(true) +} + +// readRecordOrCCS reads one or more TLS records from the connection and +// updates the record layer state. Some invariants: +// * c.in must be locked +// * c.input must be empty +// During the handshake one and only one of the following will happen: +// - c.hand grows +// - c.in.changeCipherSpec is called +// - an error is returned +// After the handshake one and only one of the following will happen: +// - c.hand grows +// - c.input is set +// - an error is returned +func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { + if c.in.err != nil { + return c.in.err + } + handshakeComplete := c.handshakeComplete() + + // This function modifies c.rawInput, which owns the c.input memory. + if c.input.Len() != 0 { + return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) + } + c.input.Reset(nil) + + // Read header, payload. + if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify + // is an error, but popular web sites seem to do this, so we accept it + // if and only if at the record boundary. + if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 { + err = io.EOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + hdr := c.rawInput.Bytes()[:recordHeaderLen] + typ := recordType(hdr[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + if !handshakeComplete && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + } + + vers := uint16(hdr[1])<<8 | uint16(hdr[2]) + n := int(hdr[3])<<8 | int(hdr[4]) + if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { + c.sendAlert(alertProtocolVersion) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if !c.haveVers { + // First message, be extra suspicious: this might not be a TLS + // client. Bail out before reading a full 'body', if possible. + // The current max version is 3.3 so if the version is >= 16.0, + // it's probably not real. + if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { + return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) + } + } + if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + + // Process message. + record := c.rawInput.Next(recordHeaderLen + n) + data, typ, err := c.in.decrypt(record) + if err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + if len(data) > maxPlaintext { + return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) + } + + // Application Data messages are always protected. + if c.in.cipher == nil && typ == recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + c.retryCount = 0 + } + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + switch typ { + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if alert(data[1]) == alertCloseNotify { + return c.in.setErrorLocked(io.EOF) + } + if c.vers == VersionTLS13 { + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + } + switch data[0] { + case alertLevelWarning: + // Drop the record on the floor and retry. + return c.retryReadRecord(expectChangeCipherSpec) + case alertLevelError: + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if len(data) != 1 || data[0] != 1 { + return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) + } + // Handshake messages are not allowed to fragment across the CCS. + if c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // In TLS 1.3, change_cipher_spec records are ignored until the + // Finished. See RFC 8446, Appendix D.4. Note that according to Section + // 5, a server can send a ChangeCipherSpec before its ServerHello, when + // c.vers is still unset. That's not useful though and suspicious if the + // server then selects a lower protocol version, so don't allow that. + if c.vers == VersionTLS13 { + return c.retryReadRecord(expectChangeCipherSpec) + } + if !expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if err := c.in.changeCipherSpec(); err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if !handshakeComplete || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. Ignore a limited number of empty records. + if len(data) == 0 { + return c.retryReadRecord(expectChangeCipherSpec) + } + // Note that data is owned by c.rawInput, following the Next call above, + // to avoid copying the plaintext. This is safe because c.rawInput is + // not read from or written to until c.input is drained. + c.input.Reset(data) + + case recordTypeHandshake: + if len(data) == 0 || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.hand.Write(data) + } + + return nil +} + +// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + } + return c.readRecordOrCCS(expectChangeCipherSpec) +} + +// atLeastReader reads from R, stopping with EOF once at least N bytes have been +// read. It is different from an io.LimitedReader in that it doesn't cut short +// the last Read call, and in that it considers an early EOF an error. +type atLeastReader struct { + R io.Reader + N int64 +} + +func (r *atLeastReader) Read(p []byte) (int, error) { + if r.N <= 0 { + return 0, io.EOF + } + n, err := r.R.Read(p) + r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 + if r.N > 0 && err == io.EOF { + return n, io.ErrUnexpectedEOF + } + if r.N <= 0 && err == nil { + return n, io.EOF + } + return n, err +} + +// readFromUntil reads from r into c.rawInput until c.rawInput contains +// at least n bytes or else returns an error. +func (c *Conn) readFromUntil(r io.Reader, n int) error { + if c.rawInput.Len() >= n { + return nil + } + needs := n - c.rawInput.Len() + // There might be extra input waiting on the wire. Make a best effort + // attempt to fetch it so that it can be used in (*Conn).Read to + // "predict" closeNotify alerts. + c.rawInput.Grow(needs + bytes.MinRead) + _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) + return err +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlert(err alert) error { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err)) + return &net.OpError{Op: "local error", Err: err} + } + + c.out.Lock() + defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= c.out.mac.Size() + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= c.out.mac.Size() + default: + panic("unknown cipher type") + } + } + if c.vers == VersionTLS13 { + payloadBytes-- // encrypted ContentType + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +func (c *Conn) write(data []byte) (int, error) { + if c.buffering { + c.sendBuf = append(c.sendBuf, data...) + return len(data), nil + } + + n, err := c.conn.Write(data) + c.bytesSent += int64(n) + return n, err +} + +func (c *Conn) flush() (int, error) { + if len(c.sendBuf) == 0 { + return 0, nil + } + + n, err := c.conn.Write(c.sendBuf) + c.bytesSent += int64(n) + c.sendBuf = nil + c.buffering = false + return n, err +} + +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() any { + return new([]byte) + }, +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + + var n int + for len(data) > 0 { + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } else if vers == VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = VersionTLS12 + } + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) + + var err error + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) + if err != nil { + return n, err + } + if _, err := c.write(outBuf); err != nil { + return n, err + } + n += m + data = data[m:] + } + + if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) + } + } + + return n, nil +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + if typ == recordTypeChangeCipherSpec { + return len(data), nil + } + return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) + } + + c.out.Lock() + defer c.out.Unlock() + + return c.writeRecordLocked(typ, data) +} + +// readHandshake reads the next handshake message from +// the record layer. +func (c *Conn) readHandshake() (any, error) { + var data []byte + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + var err error + data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage() + if err != nil { + return nil, err + } + } else { + for c.hand.Len() < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + + data = c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) + } + for c.hand.Len() < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + data = c.hand.Next(4 + n) + } + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeNewSessionTicket: + if c.vers == VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + m = new(newSessionTicketMsg) + } + case typeCertificate: + if c.vers == VersionTLS13 { + m = new(certificateMsgTLS13) + } else { + m = new(certificateMsg) + } + case typeCertificateRequest: + if c.vers == VersionTLS13 { + m = new(certificateRequestMsgTLS13) + } else { + m = &certificateRequestMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + case typeFinished: + m = new(finishedMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return m, nil +} + +var ( + errShutdown = errors.New("tls: protocol is shutdown") +) + +// Write writes data to the connection. +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Write(b []byte) (int, error) { + // interlock with Close below + for { + x := atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return 0, net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + break + } + } + defer atomic.AddInt32(&c.activeCall, -2) + + if err := c.Handshake(); err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if err := c.out.err; err != nil { + return 0, err + } + + if !c.handshakeComplete() { + return 0, alertInternalError + } + + if c.closeNotifySent { + return 0, errShutdown + } + + // TLS 1.0 is susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // https://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // https://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && c.vers == VersionTLS10 { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) +} + +// handleRenegotiation processes a HelloRequest handshake message. +func (c *Conn) handleRenegotiation() error { + if c.vers == VersionTLS13 { + return errors.New("tls: internal error: unexpected renegotiation") + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + helloReq, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(helloReq, msg) + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + atomic.StoreUint32(&c.handshakeStatus, 0) + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +func (c *Conn) HandlePostHandshakeMessage() error { + return c.handlePostHandshakeMessage() +} + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *Conn) handlePostHandshakeMessage() error { + if c.vers != VersionTLS13 { + return c.handleRenegotiation() + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + } + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + return c.handleNewSessionTicket(msg) + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + default: + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) + } +} + +func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil { + return c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + + newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) + c.in.setTrafficSecret(cipherSuite, newSecret) + + if keyUpdate.updateRequested { + c.out.Lock() + defer c.out.Unlock() + + msg := &keyUpdateMsg{} + _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + if err != nil { + // Surface the error at the next write. + c.out.setErrorLocked(err) + return nil + } + + newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) + c.out.setTrafficSecret(cipherSuite, newSecret) + } + + return nil +} + +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Read(b []byte) (int, error) { + if err := c.Handshake(); err != nil { + return 0, err + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + c.in.Lock() + defer c.in.Unlock() + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + return 0, err + } + for c.hand.Len() > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + return 0, err + } + } + } + + n, _ := c.input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && + recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +// Close closes the connection. +func (c *Conn) Close() error { + // Interlock with Conn.Write above. + var x int32 + for { + x = atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + + var alertErr error + if c.handshakeComplete() { + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } + } + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + if !c.handshakeComplete() { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + c.out.Lock() + defer c.out.Unlock() + + if !c.closeNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// HandshakeContext or the Dialer's DialContext method instead. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + // Fast sync/atomic-based exit if there is no handshake in flight and the + // last one succeeded without an error. Avoids the expensive context setup + // and mutex for most Read and Write calls. + if c.handshakeComplete() { + return nil + } + + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete() { + return nil + } + + c.in.Lock() + defer c.in.Unlock() + + c.handshakeErr = c.handshakeFn(handshakeCtx) + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the handshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + if c.handshakeErr == nil && !c.handshakeComplete() { + c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") + } + if c.handshakeErr != nil && c.handshakeComplete() { + panic("tls: internal error: handshake returned an error but is marked successful") + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} + +// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection. +func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return ConnectionStateWith0RTT{ + ConnectionState: c.connectionStateLocked(), + Used0RTT: c.used0RTT, + } +} + +func (c *Conn) connectionStateLocked() ConnectionState { + var state connectionState + state.HandshakeComplete = c.handshakeComplete() + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = true + state.ServerName = c.serverName + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } + return toConnectionState(state) +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.handshakeComplete() { + return errors.New("tls: handshake has not yet been performed") + } + if len(c.verifiedChains) == 0 { + return errors.New("tls: handshake did not verify certificate chain") + } + return c.peerCertificates[0].VerifyHostname(host) +} + +func (c *Conn) handshakeComplete() bool { + return atomic.LoadUint32(&c.handshakeStatus) == 1 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/cpu.go b/vendor/github.com/marten-seemann/qtls-go1-18/cpu.go new file mode 100644 index 00000000..12194508 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/cpu.go @@ -0,0 +1,22 @@ +//go:build !js +// +build !js + +package qtls + +import ( + "runtime" + + "golang.org/x/sys/cpu" +) + +var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && + (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || + runtime.GOARCH == "arm64" && hasGCMAsmARM64 || + runtime.GOARCH == "s390x" && hasGCMAsmS390X +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/cpu_other.go b/vendor/github.com/marten-seemann/qtls-go1-18/cpu_other.go new file mode 100644 index 00000000..33f7d219 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/cpu_other.go @@ -0,0 +1,12 @@ +//go:build js +// +build js + +package qtls + +var ( + hasGCMAsmAMD64 = false + hasGCMAsmARM64 = false + hasGCMAsmS390X = false + + hasAESGCMHardwareSupport = false +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client.go b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client.go new file mode 100644 index 00000000..ab691d56 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client.go @@ -0,0 +1,1111 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "strings" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +const clientSessionStateVersion = 1 + +type clientHandshakeState struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *clientSessionState +} + +func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { + config := c.config + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + nextProtosLength := 0 + for _, proto := range config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return nil, nil, errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + if nextProtosLength > 0xffff { + return nil, nil, errors.New("tls: NextProtos values too large") + } + + var supportedVersions []uint16 + var clientHelloVersion uint16 + if c.extraConfig.usesAlternativeRecordLayer() { + if config.maxSupportedVersion(roleClient) < VersionTLS13 { + return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + // Only offer TLS 1.3 when QUIC is used. + supportedVersions = []uint16{VersionTLS13} + clientHelloVersion = VersionTLS13 + } else { + supportedVersions = config.supportedVersions(roleClient) + if len(supportedVersions) == 0 { + return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") + } + clientHelloVersion = config.maxSupportedVersion(roleClient) + } + + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if clientHelloVersion > VersionTLS12 { + clientHelloVersion = VersionTLS12 + } + + hello := &clientHelloMsg{ + vers: clientHelloVersion, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + ocspStapling: true, + scts: true, + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + secureRenegotiationSupported: true, + alpnProtocols: config.NextProtos, + supportedVersions: supportedVersions, + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + configCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(configCipherSuites)) + + for _, suiteId := range preferenceOrder { + suite := mutualCipherSuite(configCipherSuites, suiteId) + if suite == nil { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + continue + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + } + + _, err := io.ReadFull(config.rand(), hello.random) + if err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + hello.sessionId = make([]byte, 32) + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + } + + if hello.vers >= VersionTLS12 { + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + var params ecdheParameters + if hello.supportedVersions[0] == VersionTLS13 { + var suites []uint16 + for _, suiteID := range configCipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + suites = append(suites, suiteID) + } + } + } + if len(suites) > 0 { + hello.cipherSuites = suites + } else { + if hasAESGCMHardwareSupport { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) + } else { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) + } + } + + curveID := config.curvePreferences()[0] + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err = generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil { + hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello) + } + + return hello, params, nil +} + +func (c *Conn) clientHandshake(ctx context.Context) (err error) { + if c.config == nil { + c.config = fromConfig(defaultConfig()) + } + c.setAlternativeRecordLayer() + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, ecdheParams, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName + + cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + if cacheKey != "" && session != nil { + var deletedTicket bool + if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { + // don't reuse a session ticket that enabled 0-RTT + c.config.ClientSessionCache.Put(cacheKey, nil) + deletedTicket = true + + if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { + h := suite.hash.New() + h.Write(hello.marshal()) + clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) + c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + } + if !deletedTicket { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + }() + } + } + + if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion(roleClient) + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + + if c.vers == VersionTLS13 { + hs := &clientHandshakeStateTLS13{ + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + ecdheParams: ecdheParams, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + } + + // In TLS 1.3, session tickets are delivered after the handshake. + return hs.handshake() + } + + hs := &clientHandshakeState{ + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + session: session, + } + + if err := hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one. + if cacheKey != "" && hs.session != nil && session != hs.session { + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) + } + + return nil +} + +// extract the app data saved in the session.nonce, +// and set the session.nonce to the actual nonce value +func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) { + s := cryptobyte.String(session.nonce) + var version uint16 + if !s.ReadUint16(&version) { + return 0, nil, false + } + if version != clientSessionStateVersion { + return 0, nil, false + } + var maxEarlyData uint32 + if !s.ReadUint32(&maxEarlyData) { + return 0, nil, false + } + var appData []byte + if !readUint16LengthPrefixed(&s, &appData) { + return 0, nil, false + } + var nonce []byte + if !readUint16LengthPrefixed(&s, &nonce) { + return 0, nil, false + } + session.nonce = nonce + return maxEarlyData, appData, true +} + +func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, + session *clientSessionState, earlySecret, binderKey []byte) { + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return "", nil, nil, nil + } + + hello.ticketSupported = true + + if hello.supportedVersions[0] == VersionTLS13 { + // Require DHE on resumption as it guarantees forward secrecy against + // compromise of the session ticket key. See RFC 8446, Section 4.2.9. + hello.pskModes = []uint8{pskModeDHE} + } + + // Session resumption is not allowed if renegotiating because + // renegotiation is primarily used to allow a client to send a client + // certificate, which would be skipped if session resumption occurred. + if c.handshakes != 0 { + return "", nil, nil, nil + } + + // Try to resume a previously negotiated TLS session, if available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + sess, ok := c.config.ClientSessionCache.Get(cacheKey) + if !ok || sess == nil { + return cacheKey, nil, nil, nil + } + session = fromClientSessionState(sess) + + var appData []byte + var maxEarlyData uint32 + if session.vers == VersionTLS13 { + var ok bool + maxEarlyData, appData, ok = c.decodeSessionState(session) + if !ok { // delete it, if parsing failed + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + } + + // Check that version used for the previous session is still valid. + versOk := false + for _, v := range hello.supportedVersions { + if v == session.vers { + versOk = true + break + } + } + if !versOk { + return cacheKey, nil, nil, nil + } + + // Check that the cached server certificate is not expired, and that it's + // valid for the ServerName. This should be ensured by the cache key, but + // protect the application from a faulty ClientSessionCache implementation. + if !c.config.InsecureSkipVerify { + if len(session.verifiedChains) == 0 { + // The original connection had InsecureSkipVerify, while this doesn't. + return cacheKey, nil, nil, nil + } + serverCert := session.serverCertificates[0] + if c.config.time().After(serverCert.NotAfter) { + // Expired certificate, delete the entry. + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { + return cacheKey, nil, nil, nil + } + } + + if session.vers != VersionTLS13 { + // In TLS 1.2 the cipher suite must match the resumed session. Ensure we + // are still offering it. + if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { + return cacheKey, nil, nil, nil + } + + hello.sessionTicket = session.sessionTicket + return + } + + // Check that the session ticket is not expired. + if c.config.time().After(session.useBy) { + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + + // In TLS 1.3 the KDF hash must match the resumed session. Ensure we + // offer at least one cipher suite with that hash. + cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) + if cipherSuite == nil { + return cacheKey, nil, nil, nil + } + cipherSuiteOk := false + for _, offeredID := range hello.cipherSuites { + offeredSuite := cipherSuiteTLS13ByID(offeredID) + if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return cacheKey, nil, nil, nil + } + + // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. + ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) + identity := pskIdentity{ + label: session.sessionTicket, + obfuscatedTicketAge: ticketAge + session.ageAdd, + } + hello.pskIdentities = []pskIdentity{identity} + hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} + + // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. + psk := cipherSuite.expandLabel(session.masterSecret, "resumption", + session.nonce, cipherSuite.hash.Size()) + earlySecret = cipherSuite.extract(psk, nil) + binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + if c.extraConfig != nil { + hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 + } + transcript := cipherSuite.hash.New() + transcript.Write(hello.marshalWithoutBinders()) + pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} + hello.updateBinders(pskBinders) + + if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { + c.extraConfig.SetAppDataFromSessionState(appData) + } + return +} + +func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { + peerVersion := serverHello.vers + if serverHello.supportedVersion != 0 { + peerVersion = serverHello.supportedVersion + } + + vers, ok := c.config.mutualVersion(roleClient, []uint16{peerVersion}) + if !ok { + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) + } + + c.vers = vers + c.haveVers = true + c.in.version = vers + c.out.version = vers + + return nil +} + +// Does the handshake, either a full one or resumes old session. Requires hs.c, +// hs.hello, hs.serverHello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + c := hs.c + + isResume, err := hs.processServerHello() + if err != nil { + return err + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + + // No signatures of the handshake are needed in a resumption. + // Otherwise, in a full handshake, if we don't have any certificates + // configured then we will never send a CertificateVerify message and + // thus no signatures are needed in that case either. + if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { + hs.finishedHash.discardHandshakeBuffer() + } + + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + + c.buffering = true + c.didResume = isResume + if isResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = true + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + + hs.c.cipherSuite = hs.suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + msg, err = c.readHandshake() + if err != nil { + return err + } + + cs, ok := msg.(*certificateStatusMsg) + if ok { + // RFC4366 on Certificate Status Request: + // The server MAY return a "certificate_status" message. + + if !hs.serverHello.ocspStapling { + // If a server returns a "CertificateStatus" message, then the + // server MUST have included an extension of type "status_request" + // with empty "extension_data" in the extended server hello. + + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received unexpected CertificateStatus message") + } + hs.finishedHash.Write(cs.marshal()) + + c.ocspResponse = cs.response + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + + keyAgreement := hs.suite.ka(c.vers) + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.finishedHash.Write(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + hs.finishedHash.Write(certReq.marshal()) + + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) + if chainToSend, err = c.getClientCertificate(cri); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.finishedHash.Write(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg = new(certificateMsg) + certMsg.certificates = chainToSend.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + hs.finishedHash.Write(ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } + } + + if chainToSend != nil && len(chainToSend.Certificate) > 0 { + certVerify := &certificateVerifyMsg{} + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + certVerify.hasSignatureAlgorithm = true + certVerify.signatureAlgorithm = signatureAlgorithm + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.finishedHash.Write(certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } + } + + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to write to key log: " + err.Error()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher any + var clientHash, serverHash hash.Hash + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if err := hs.pickCipherSuite(); err != nil { + return false, err + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { + c.secureRenegotiation = true + if len(hs.serverHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + } + + if c.handshakes > 0 && c.secureRenegotiation { + var expectedSecureRenegotiation [24]byte + copy(expectedSecureRenegotiation[:], c.clientFinished[:]) + copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) + if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: incorrect renegotiation extension contents") + } + } + + if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { + c.sendAlert(alertUnsupportedExtension) + return false, err + } + c.clientProtocol = hs.serverHello.alpnProtocol + + c.scts = hs.serverHello.scts + + if !hs.serverResumedSession() { + return false, nil + } + + if hs.session.vers != c.vers { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different version") + } + + if hs.session.cipherSuite != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different cipher suite") + } + + // Restore masterSecret, peerCerts, and ocspResponse from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + // Let the ServerHello SCTs override the session SCTs from the original + // connection, if any are provided + if len(c.scts) == 0 && len(hs.session.scts) != 0 { + c.scts = hs.session.scts + } + + return true, nil +} + +// checkALPN ensure that the server's choice of ALPN protocol is compatible with +// the protocols that we advertised in the Client Hello. +func checkALPN(clientProtos []string, serverProto string) error { + if serverProto == "" { + return nil + } + if len(clientProtos) == 0 { + return errors.New("tls: server advertised unrequested ALPN extension") + } + for _, proto := range clientProtos { + if proto == serverProto { + return nil + } + } + return errors.New("tls: server selected unadvertised ALPN protocol") +} + +func (hs *clientHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &clientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + return nil +} + +func (hs *clientHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + copy(out, finished.verifyData) + return nil +} + +// verifyServerCertificate parses and verifies the provided chain, setting +// c.verifiedChains and c.peerCertificates or sending the appropriate alert. +func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + var err error + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + switch certs[0].PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) + } + + c.peerCertificates = certs + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS +// <= 1.2 CertificateRequest, making an effort to fill in missing information. +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { + cri := &certificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + Version: vers, + ctx: ctx, + } + + var rsaAvail, ecAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecAvail = true + } + } + + if !certReq.hasSignatureAlgorithm { + // Prior to TLS 1.2, signature schemes did not exist. In this case we + // make up a list based on the acceptable certificate types, to help + // GetClientCertificate and SupportsCertificate select the right certificate. + // The hash part of the SignatureScheme is a lie here, because + // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. + switch { + case rsaAvail && ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case rsaAvail: + cri.SignatureSchemes = []SignatureScheme{ + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + } + } + return toCertificateRequestInfo(cri) + } + + // Filter the signature schemes based on the certificate types. + // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). + cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) + for _, sigScheme := range certReq.supportedSignatureAlgorithms { + sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) + if err != nil { + continue + } + switch sigType { + case signatureECDSA, signatureEd25519: + if ecAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + case signatureRSAPSS, signaturePKCS1v15: + if rsaAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + } + } + + return toCertificateRequestInfo(cri) +} + +func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { + if c.config.GetClientCertificate != nil { + return c.config.GetClientCertificate(cri) + } + + for _, chain := range c.config.Certificates { + if err := cri.SupportsCertificate(&chain); err != nil { + continue + } + return &chain, nil + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + +const clientSessionCacheKeyPrefix = "qtls-" + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *config) string { + if len(config.ServerName) > 0 { + return clientSessionCacheKeyPrefix + config.ServerName + } + return clientSessionCacheKeyPrefix + serverAddr.String() +} + +// hostnameInSNI converts name into an appropriate hostname for SNI. +// Literal IP addresses and absolute FQDNs are not permitted as SNI values. +// See RFC 6066, Section 3. +func hostnameInSNI(name string) string { + host := name + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + if i := strings.LastIndex(host, "%"); i > 0 { + host = host[:i] + } + if net.ParseIP(host) != nil { + return "" + } + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return name +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client_tls13.go new file mode 100644 index 00000000..0de59fc1 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client_tls13.go @@ -0,0 +1,732 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/hmac" + "crypto/rsa" + "encoding/binary" + "errors" + "hash" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +type clientHandshakeStateTLS13 struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheParams ecdheParameters + + session *clientSessionState + earlySecret []byte + binderKey []byte + + certReq *certificateRequestMsgTLS13 + usingPSK bool + sentDummyCCS bool + suite *cipherSuiteTLS13 + transcript hash.Hash + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 +} + +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. +func (hs *clientHandshakeStateTLS13) handshake() error { + c := hs.c + + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, + // sections 4.1.2 and 4.1.3. + if c.handshakes > 0 { + c.sendAlert(alertProtocolVersion) + return errors.New("tls: server selected TLS 1.3 in a renegotiation") + } + + // Consistency check on the presence of a keyShare and its parameters. + if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + return c.sendAlert(alertInternalError) + } + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + hs.transcript = hs.suite.hash.New() + hs.transcript.Write(hs.hello.marshal()) + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.processHelloRetryRequest(); err != nil { + return err + } + } + + hs.transcript.Write(hs.serverHello.marshal()) + + c.buffering = true + if err := hs.processServerHello(); err != nil { + return err + } + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.establishHandshakeKeys(); err != nil { + return err + } + if err := hs.readServerParameters(); err != nil { + return err + } + if err := hs.readServerCertificate(); err != nil { + return err + } + if err := hs.readServerFinished(); err != nil { + return err + } + if err := hs.sendClientCertificate(); err != nil { + return err + } + if err := hs.sendClientFinished(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// checkServerHelloOrHRR does validity checks that apply to both ServerHello and +// HelloRetryRequest messages. It sets hs.suite. +func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { + c := hs.c + + if hs.serverHello.supportedVersion == 0 { + c.sendAlert(alertMissingExtension) + return errors.New("tls: server selected TLS 1.3 using the legacy version field") + } + + if hs.serverHello.supportedVersion != VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid version after a HelloRetryRequest") + } + + if hs.serverHello.vers != VersionTLS12 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an incorrect legacy version") + } + + if hs.serverHello.ocspStapling || + hs.serverHello.ticketSupported || + hs.serverHello.secureRenegotiationSupported || + len(hs.serverHello.secureRenegotiation) != 0 || + len(hs.serverHello.alpnProtocol) != 0 || + len(hs.serverHello.scts) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") + } + + if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not echo the legacy session ID") + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported compression format") + } + + selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) + if hs.suite != nil && selectedSuite != hs.suite { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server changed cipher suite after a HelloRetryRequest") + } + if selectedSuite == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server chose an unconfigured cipher suite") + } + hs.suite = selectedSuite + c.cipherSuite = hs.suite.id + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and +// resends hs.hello, and reads the new ServerHello into hs.serverHello. +func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. (The idea is that the server might offload transcript + // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + hs.transcript.Write(hs.serverHello.marshal()) + + // The only HelloRetryRequest extensions we support are key_share and + // cookie, and clients must abort the handshake if the HRR would not result + // in any change in the ClientHello. + if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest message") + } + + if hs.serverHello.cookie != nil { + hs.hello.cookie = hs.serverHello.cookie + } + + if hs.serverHello.serverShare.group != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received malformed key_share extension") + } + + // If the server sent a key_share extension selecting a group, ensure it's + // a group we advertised but did not send a key share for, and send a key + // share for it this time. + if curveID := hs.serverHello.selectedGroup; curveID != 0 { + curveOK := false + for _, id := range hs.hello.supportedCurves { + if id == curveID { + curveOK = true + break + } + } + if !curveOK { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + if hs.ecdheParams.CurveID() == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.ecdheParams = params + hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + hs.hello.raw = nil + if len(hs.hello.pskIdentities) > 0 { + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash == hs.suite.hash { + // Update binders and obfuscated_ticket_age. + ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) + hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd + + transcript := hs.suite.hash.New() + transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + transcript.Write(chHash) + transcript.Write(hs.serverHello.marshal()) + transcript.Write(hs.hello.marshalWithoutBinders()) + pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} + hs.hello.updateBinders(pskBinders) + } else { + // Server selected a cipher suite incompatible with the PSK. + hs.hello.pskIdentities = nil + hs.hello.pskBinders = nil + } + } + + if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + hs.hello.earlyData = false // disable 0-RTT + + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + hs.serverHello = serverHello + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) processServerHello() error { + c := hs.c + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: server sent two HelloRetryRequest messages") + } + + if len(hs.serverHello.cookie) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a cookie in a normal ServerHello") + } + + if hs.serverHello.selectedGroup != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: malformed key_share extension") + } + + if hs.serverHello.serverShare.group == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not send a key share") + } + if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + + if !hs.serverHello.selectedIdentityPresent { + return nil + } + + if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK") + } + + if len(hs.hello.pskIdentities) != 1 || hs.session == nil { + return c.sendAlert(alertInternalError) + } + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash != hs.suite.hash { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK and cipher suite pair") + } + + hs.usingPSK = true + c.didResume = true + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + c.scts = hs.session.scts + return nil +} + +func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { + c := hs.c + + sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) + if sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + + earlySecret := hs.earlySecret + if !hs.usingPSK { + earlySecret = hs.suite.extract(nil, nil) + } + handshakeSecret := hs.suite.extract(sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.out.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(handshakeSecret, "derived", nil)) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerParameters() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + // Notify the caller if 0-RTT was rejected. + if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + c.used0RTT = encryptedExtensions.earlyData + if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { + hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) + } + hs.transcript.Write(encryptedExtensions.marshal()) + + if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { + c.sendAlert(alertUnsupportedExtension) + return err + } + c.clientProtocol = encryptedExtensions.alpnProtocol + + if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection { + if len(encryptedExtensions.alpnProtocol) == 0 { + // the server didn't select an ALPN + c.sendAlert(alertNoApplicationProtocol) + return errors.New("ALPN negotiation failed. Server didn't offer any protocols") + } + } + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerCertificate() error { + c := hs.c + + // Either a PSK or a certificate is always used, but not both. + // See RFC 8446, Section 4.1.1. + if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certReq, ok := msg.(*certificateRequestMsgTLS13) + if ok { + hs.transcript.Write(certReq.marshal()) + + hs.certReq = certReq + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + if len(certMsg.certificate.Certificate) == 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received empty certificates message") + } + hs.transcript.Write(certMsg.marshal()) + + c.scts = certMsg.certificate.SignedCertificateTimestamps + c.ocspResponse = certMsg.certificate.OCSPStaple + + if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + if !hmac.Equal(expectedMAC, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid server finished hash") + } + + hs.transcript.Write(finished.marshal()) + + // Derive secrets that take context through the server Finished. + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { + c := hs.c + + if hs.certReq == nil { + return nil + } + + cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{ + AcceptableCAs: hs.certReq.certificateAuthorities, + SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, + Version: c.vers, + ctx: hs.ctx, + })) + if err != nil { + return err + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *cert + certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + // If we sent an empty certificate message, skip the CertificateVerify. + if len(cert.Certificate) == 0 { + return nil + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + + certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) + if err != nil { + // getClientCertificate returned a certificate incompatible with the + // CertificateRequestInfo supported signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.out.setTrafficSecret(hs.suite, hs.trafficSecret) + + if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + } + + return nil +} + +func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received new session ticket from a client") + } + + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return nil + } + + // See RFC 8446, Section 4.6.1. + if msg.lifetime == 0 { + return nil + } + lifetime := time.Duration(msg.lifetime) * time.Second + if lifetime > maxSessionTicketLifetime { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: received a session ticket with invalid lifetime") + } + + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil || c.resumptionSecret == nil { + return c.sendAlert(alertInternalError) + } + + // We need to save the max_early_data_size that the server sent us, in order + // to decide if we're going to try 0-RTT with this ticket. + // However, at the same time, the qtls.ClientSessionTicket needs to be equal to + // the tls.ClientSessionTicket, so we can't just add a new field to the struct. + // We therefore abuse the nonce field (which is a byte slice) + nonceWithEarlyData := make([]byte, len(msg.nonce)+4) + binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData) + copy(nonceWithEarlyData[4:], msg.nonce) + + var appData []byte + if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil { + appData = c.extraConfig.GetAppDataForSessionState() + } + var b cryptobyte.Builder + b.AddUint16(clientSessionStateVersion) // revision + b.AddUint32(msg.maxEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(appData) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(msg.nonce) + }) + + // Save the resumption_master_secret and nonce instead of deriving the PSK + // to do the least amount of work on NewSessionTicket messages before we + // know if the ticket will be used. Forward secrecy of resumed connections + // is guaranteed by the requirement for pskModeDHE. + session := &clientSessionState{ + sessionTicket: msg.label, + vers: c.vers, + cipherSuite: c.cipherSuite, + masterSecret: c.resumptionSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + nonce: b.BytesOrPanic(), + useBy: c.config.time().Add(lifetime), + ageAdd: msg.ageAdd, + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session)) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_messages.go b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_messages.go new file mode 100644 index 00000000..5f87d4b8 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_messages.go @@ -0,0 +1,1831 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "fmt" + "strings" + + "golang.org/x/crypto/cryptobyte" +) + +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. +func addUint64(b *cryptobyte.Builder, v uint64) { + b.AddUint32(uint32(v >> 32)) + b.AddUint32(uint32(v)) +} + +// readUint64 decodes a big-endian, 64-bit value into out and advances over it. +// It reports whether the read was successful. +func readUint64(s *cryptobyte.String, out *uint64) bool { + var hi, lo uint32 + if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { + return false + } + *out = uint64(hi)<<32 | uint64(lo) + return true +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocols []string + scts bool + supportedVersions []uint16 + cookie []byte + keyShares []keyShare + earlyData bool + pskModes []uint8 + pskIdentities []pskIdentity + pskBinders [][]byte + additionalExtensions []Extension +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeClientHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + b.AddUint16(suite) + } + }) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.compressionMethods) + }) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + b.AddUint16(extensionServerName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // name_type = host_name + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(1) // status_type = ocsp + b.AddUint16(0) // empty responder_id_list + b.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + b.AddUint16(extensionSupportedCurves) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + b.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + b.AddUint16(extensionSessionTicket) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + b.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ks := range m.keyShares { + b.AddUint16(uint16(ks.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + b.AddUint16(extensionPSKModes) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.pskModes) + }) + }) + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(psk.label) + }) + b.AddUint32(psk.obfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +// marshalWithoutBinders returns the ClientHello through the +// PreSharedKeyExtension.identities field, according to RFC 8446, Section +// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. +func (m *clientHelloMsg) marshalWithoutBinders() []byte { + bindersLen := 2 // uint16 length prefix + for _, binder := range m.pskBinders { + bindersLen += 1 // uint8 length prefix + bindersLen += len(binder) + } + + fullMessage := m.marshal() + return fullMessage[:len(fullMessage)-bindersLen] +} + +// updateBinders updates the m.pskBinders field, if necessary updating the +// cached marshaled representation. The supplied binders must have the same +// length as the current m.pskBinders. +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { + if len(pskBinders) != len(m.pskBinders) { + panic("tls: internal error: pskBinders length mismatch") + } + for i := range m.pskBinders { + if len(pskBinders[i]) != len(m.pskBinders[i]) { + panic("tls: internal error: pskBinders length mismatch") + } + } + m.pskBinders = pskBinders + if m.raw != nil { + lenWithoutBinders := len(m.marshalWithoutBinders()) + b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { + panic("tls: internal error: failed to update binders") + } + } +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + *m = clientHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) { + return false + } + + var cipherSuites cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuites) { + return false + } + m.cipherSuites = []uint16{} + m.secureRenegotiationSupported = false + for !cipherSuites.Empty() { + var suite uint16 + if !cipherSuites.ReadUint16(&suite) { + return false + } + if suite == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + m.cipherSuites = append(m.cipherSuites, suite) + } + + if !readUint8LengthPrefixed(&s, &m.compressionMethods) { + return false + } + + if s.Empty() { + // ClientHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var ext uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&ext) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch ext { + case extensionServerName: + // RFC 6066, Section 3 + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return false + } + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return false + } + if nameType != 0 { + continue + } + if len(m.serverName) != 0 { + // Multiple names of the same name_type are prohibited. + return false + } + m.serverName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(m.serverName, ".") { + return false + } + } + case extensionStatusRequest: + // RFC 4366, Section 3.6 + var statusType uint8 + var ignored cryptobyte.String + if !extData.ReadUint8(&statusType) || + !extData.ReadUint16LengthPrefixed(&ignored) || + !extData.ReadUint16LengthPrefixed(&ignored) { + return false + } + m.ocspStapling = statusType == statusTypeOCSP + case extensionSupportedCurves: + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + var curves cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { + return false + } + for !curves.Empty() { + var curve uint16 + if !curves.ReadUint16(&curve) { + return false + } + m.supportedCurves = append(m.supportedCurves, CurveID(curve)) + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + case extensionSessionTicket: + // RFC 5077, Section 3.2 + m.ticketSupported = true + extData.ReadBytes(&m.sessionTicket, len(extData)) + case extensionSignatureAlgorithms: + // RFC 5246, Section 7.4.1.4.1 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + // RFC 8446, Section 4.2.3 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionRenegotiationInfo: + // RFC 5746, Section 3.2 + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + // RFC 7301, Section 3.1 + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(proto)) + } + case extensionSCT: + // RFC 6962, Section 3.3.1 + m.scts = true + case extensionSupportedVersions: + // RFC 8446, Section 4.2.1 + var versList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { + return false + } + for !versList.Empty() { + var vers uint16 + if !versList.ReadUint16(&vers) { + return false + } + m.supportedVersions = append(m.supportedVersions, vers) + } + case extensionCookie: + // RFC 8446, Section 4.2.2 + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // RFC 8446, Section 4.2.8 + var clientShares cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&clientShares) { + return false + } + for !clientShares.Empty() { + var ks keyShare + if !clientShares.ReadUint16((*uint16)(&ks.group)) || + !readUint16LengthPrefixed(&clientShares, &ks.data) || + len(ks.data) == 0 { + return false + } + m.keyShares = append(m.keyShares, ks) + } + case extensionEarlyData: + // RFC 8446, Section 4.2.10 + m.earlyData = true + case extensionPSKModes: + // RFC 8446, Section 4.2.9 + if !readUint8LengthPrefixed(&extData, &m.pskModes) { + return false + } + case extensionPreSharedKey: + // RFC 8446, Section 4.2.11 + if !extensions.Empty() { + return false // pre_shared_key must be the last extension + } + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return false + } + for !identities.Empty() { + var psk pskIdentity + if !readUint16LengthPrefixed(&identities, &psk.label) || + !identities.ReadUint32(&psk.obfuscatedTicketAge) || + len(psk.label) == 0 { + return false + } + m.pskIdentities = append(m.pskIdentities, psk) + } + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return false + } + for !binders.Empty() { + var binder []byte + if !readUint8LengthPrefixed(&binders, &binder) || + len(binder) == 0 { + return false + } + m.pskBinders = append(m.pskBinders, binder) + } + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + ocspStapling bool + ticketSupported bool + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocol string + scts [][]byte + supportedVersion uint16 + serverShare keyShare + selectedIdentityPresent bool + selectedIdentity uint16 + supportedPoints []uint8 + + // HelloRetryRequest extensions + cookie []byte + selectedGroup CurveID +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeServerHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16(m.cipherSuite) + b.AddUint8(m.compressionMethod) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + b.AddUint16(extensionSessionTicket) + b.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.scts { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.serverShare.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + *m = serverHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) || + !s.ReadUint16(&m.cipherSuite) || + !s.ReadUint8(&m.compressionMethod) { + return false + } + + if s.Empty() { + // ServerHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSessionTicket: + m.ticketSupported = true + case extensionRenegotiationInfo: + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + m.scts = append(m.scts, sct) + } + case extensionSupportedVersions: + if !extData.ReadUint16(&m.supportedVersion) { + return false + } + case extensionCookie: + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // This extension has different formats in SH and HRR, accept either + // and let the handshake logic decide. See RFC 8446, Section 4.2.8. + if len(extData) == 2 { + if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { + return false + } + } else { + if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || + !readUint16LengthPrefixed(&extData, &m.serverShare.data) { + return false + } + } + case extensionPreSharedKey: + m.selectedIdentityPresent = true + if !extData.ReadUint16(&m.selectedIdentity) { + return false + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string + earlyData bool + + additionalExtensions []Extension +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeEncryptedExtensions) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + *m = encryptedExtensionsMsg{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var ext uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&ext) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch ext { + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionEarlyData: + m.earlyData = true + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type endOfEarlyDataMsg struct{} + +func (m *endOfEarlyDataMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeEndOfEarlyData + return x +} + +func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type keyUpdateMsg struct { + raw []byte + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +type newSessionTicketMsgTLS13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateRequestMsgTLS13 struct { + raw []byte + ocspStapling bool + scts bool + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateRequest) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // certificate_request_context (SHALL be zero length unless used for + // post-handshake authentication) + b.AddUint8(0) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.scts { + // RFC 8446, Section 4.4.2.1 makes no mention of + // signed_certificate_timestamp in CertificateRequest, but + // "Extensions in the Certificate message from the client MUST + // correspond to extensions in the CertificateRequest message + // from the server." and it appears in the table in Section 4.2. + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedSignatureAlgorithms) > 0 { + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + b.AddUint16(extensionCertificateAuthorities) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ca) + }) + } + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { + *m = certificateRequestMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSCT: + m.scts = true + case extensionSignatureAlgorithms: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionCertificateAuthorities: + var auths cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { + return false + } + for !auths.Empty() { + var ca []byte + if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { + return false + } + m.certificateAuthorities = append(m.certificateAuthorities, ca) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type certificateMsgTLS13 struct { + raw []byte + certificate Certificate + ocspStapling bool + scts bool +} + +func (m *certificateMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // certificate_request_context + + certificate := m.certificate + if !m.ocspStapling { + certificate.OCSPStaple = nil + } + if !m.scts { + certificate.SignedCertificateTimestamps = nil + } + marshalCertificate(b, certificate) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for i, cert := range certificate.Certificate { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if i > 0 { + // This library only supports OCSP and SCT for leaf certificates. + return + } + if certificate.OCSPStaple != nil { + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(certificate.OCSPStaple) + }) + }) + } + if certificate.SignedCertificateTimestamps != nil { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range certificate.SignedCertificateTimestamps { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + }) + } + }) +} + +func (m *certificateMsgTLS13) unmarshal(data []byte) bool { + *m = certificateMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !unmarshalCertificate(&s, &m.certificate) || + !s.Empty() { + return false + } + + m.scts = m.certificate.SignedCertificateTimestamps != nil + m.ocspStapling = m.certificate.OCSPStaple != nil + + return true +} + +func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + var extensions cryptobyte.String + if !readUint24LengthPrefixed(&certList, &cert) || + !certList.ReadUint16LengthPrefixed(&extensions) { + return false + } + certificate.Certificate = append(certificate.Certificate, cert) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + if len(certificate.Certificate) > 1 { + // This library only supports OCSP and SCT for leaf certificates. + continue + } + + switch extension { + case extensionStatusRequest: + var statusType uint8 + if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || + len(certificate.OCSPStaple) == 0 { + return false + } + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + certificate.SignedCertificateTimestamps = append( + certificate.SignedCertificateTimestamps, sct) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + } + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + response []byte +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateStatus) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.response) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var statusType uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&s, &m.response) || + len(m.response) == 0 || !s.Empty() { + return false + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeFinished) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.verifyData) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + return s.Skip(1) && + readUint24LengthPrefixed(&s, &m.verifyData) && + s.Empty() +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAlgorithm indicates whether this message includes a list of + // supported signature algorithms. This change was introduced with TLS 1.2. + hasSignatureAlgorithm bool + + certificateTypes []byte + supportedSignatureAlgorithms []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 4346, Section 7.4.4. + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAlgorithm { + length += 2 + 2*len(m.supportedSignatureAlgorithms) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAlgorithm { + n := len(m.supportedSignatureAlgorithms) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAlgo := range m.supportedSignatureAlgorithms { + y[0] = uint8(sigAlgo >> 8) + y[1] = uint8(sigAlgo) + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAlgorithm { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAlgos := sigAndHashLen / 2 + m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + return len(data) == 0 +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAlgorithm bool // format change introduced in TLS 1.2 + signatureAlgorithm SignatureScheme + signature []byte +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateVerify) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + b.AddUint16(uint16(m.signatureAlgorithm)) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.signature) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + if !s.Skip(4) { // message type and uint24 length field + return false + } + if m.hasSignatureAlgorithm { + if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { + return false + } + } + return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 5077, Section 3.3. + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server.go b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server.go new file mode 100644 index 00000000..2fe82848 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server.go @@ -0,0 +1,905 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "sync/atomic" + "time" +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + ctx context.Context + clientHello *clientHelloMsg + hello *serverHelloMsg + suite *cipherSuite + ecdheOk bool + ecSignOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + masterSecret []byte + cert *Certificate +} + +// serverHandshake performs a TLS handshake as a server. +func (c *Conn) serverHandshake(ctx context.Context) error { + c.setAlternativeRecordLayer() + + clientHello, err := c.readClientHello(ctx) + if err != nil { + return err + } + + if c.vers == VersionTLS13 { + hs := serverHandshakeStateTLS13{ + c: c, + ctx: ctx, + clientHello: clientHello, + } + return hs.handshake() + } else if c.extraConfig.usesAlternativeRecordLayer() { + // This should already have been caught by the check that the ClientHello doesn't + // offer any (supported) versions older than TLS 1.3. + // Check again to make sure we can't be tricked into using an older version. + c.sendAlert(alertProtocolVersion) + return errors.New("tls: negotiated TLS < 1.3 when using QUIC") + } + + hs := serverHandshakeState{ + c: c, + ctx: ctx, + clientHello: clientHello, + } + return hs.handshake() +} + +func (hs *serverHandshakeState) handshake() error { + c := hs.c + + if err := hs.processClientHello(); err != nil { + return err + } + + // For an overview of TLS handshaking, see RFC 5246, Section 7.3. + c.buffering = true + if hs.checkForResumption() { + // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true + if err := hs.doResumeHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(c.serverFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = false + if err := hs.readFinished(nil); err != nil { + return err + } + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + if err := hs.pickCipherSuite(); err != nil { + return err + } + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readFinished(c.clientFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = true + c.buffering = true + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(nil); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// readClientHello reads a ClientHello message and selects the protocol version. +func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { + msg, err := c.readHandshake() + if err != nil { + return nil, err + } + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return nil, unexpectedMessageError(clientHello, msg) + } + + var configForClient *config + originalConfig := c.config + if c.config.GetConfigForClient != nil { + chi := newClientHelloInfo(ctx, c, clientHello) + if cfc, err := c.config.GetConfigForClient(chi); err != nil { + c.sendAlert(alertInternalError) + return nil, err + } else if cfc != nil { + configForClient = fromConfig(cfc) + c.config = configForClient + } + } + c.ticketKeys = originalConfig.ticketKeys(configForClient) + + clientVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + clientVersions = supportedVersionsFromMax(clientHello.vers) + } + if c.extraConfig.usesAlternativeRecordLayer() { + // In QUIC, the client MUST NOT offer any old TLS versions. + // Here, we can only check that none of the other supported versions of this library + // (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here. + for _, ver := range clientVersions { + if ver == VersionTLS13 { + continue + } + for _, v := range supportedVersions { + if ver == v { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver) + } + } + } + // Make the config we're using allows us to use TLS 1.3. + if c.config.maxSupportedVersion(roleServer) < VersionTLS13 { + c.sendAlert(alertInternalError) + return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + } + c.vers, ok = c.config.mutualVersion(roleServer, clientVersions) + if !ok { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) + } + c.haveVers = true + c.in.version = c.vers + c.out.version = c.vers + + return clientHello, nil +} + +func (hs *serverHandshakeState) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client does not support uncompressed connections") + } + + hs.hello.random = make([]byte, 32) + serverRandom := hs.hello.random + // Downgrade protection canaries. See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion(roleServer) + if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { + if c.vers == VersionTLS12 { + copy(serverRandom[24:], downgradeCanaryTLS12) + } else { + copy(serverRandom[24:], downgradeCanaryTLS11) + } + serverRandom = serverRandom[:24] + } + _, err := io.ReadFull(c.config.rand(), serverRandom) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + c.sendAlert(alertNoApplicationProtocol) + return err + } + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + + hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + if hs.clientHello.scts { + hs.hello.scts = hs.cert.SignedCertificateTimestamps + } + + hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) + + if hs.ecdheOk { + // Although omitting the ec_point_formats extension is permitted, some + // old OpenSSL version will refuse to handshake if not present. + // + // Per RFC 4492, section 5.1.2, implementations MUST support the + // uncompressed point format. See golang.org/issue/31943. + hs.hello.supportedPoints = []uint8{pointFormatUncompressed} + } + + if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { + switch priv.Public().(type) { + case *ecdsa.PublicKey: + hs.ecSignOk = true + case ed25519.PublicKey: + hs.ecSignOk = true + case *rsa.PublicKey: + hs.rsaSignOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) + } + } + if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { + switch priv.Public().(type) { + case *rsa.PublicKey: + hs.rsaDecryptOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) + } + } + + return nil +} + +// negotiateALPN picks a shared ALPN protocol that both sides support in server +// preference order. If ALPN is not configured or the peer doesn't support it, +// it returns "" and no error. +func negotiateALPN(serverProtos, clientProtos []string) (string, error) { + if len(serverProtos) == 0 || len(clientProtos) == 0 { + return "", nil + } + var http11fallback bool + for _, s := range serverProtos { + for _, c := range clientProtos { + if s == c { + return s, nil + } + if s == "h2" && c == "http/1.1" { + http11fallback = true + } + } + } + // As a special case, let http/1.1 clients connect to h2 servers as if they + // didn't support ALPN. We used not to enforce protocol overlap, so over + // time a number of HTTP servers were configured with only "h2", but + // expected to accept connections from "http/1.1" clients. See Issue 46310. + if http11fallback { + return "", nil + } + return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) +} + +// supportsECDHE returns whether ECDHE key exchanges can be used with this +// pre-TLS 1.3 client. +func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool { + supportsCurve := false + for _, curve := range supportedCurves { + if c.supportsCurve(curve) { + supportsCurve = true + break + } + } + + supportsPointFormat := false + for _, pointFormat := range supportedPoints { + if pointFormat == pointFormatUncompressed { + supportsPointFormat = true + break + } + } + + return supportsCurve && supportsPointFormat +} + +func (hs *serverHandshakeState) pickCipherSuite() error { + c := hs.c + + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + + configCipherSuites := c.config.cipherSuites() + preferenceList := make([]uint16, 0, len(configCipherSuites)) + for _, suiteID := range preferenceOrder { + for _, id := range configCipherSuites { + if id == suiteID { + preferenceList = append(preferenceList, id) + break + } + } + } + + hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk) + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. See RFC 7507. + if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return nil +} + +func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + if !hs.ecdheOk { + return false + } + if c.flags&suiteECSign != 0 { + if !hs.ecSignOk { + return false + } + } else if !hs.rsaSignOk { + return false + } + } else if !hs.rsaDecryptOk { + return false + } + if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true +} + +// checkForResumption reports whether we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if c.config.SessionTicketsDisabled { + return false + } + + plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) + if plaintext == nil { + return false + } + hs.sessionState = &sessionState{usedOldKey: usedOldKey} + ok := hs.sessionState.unmarshal(plaintext) + if !ok { + return false + } + + createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + return false + } + + // Never resume a session for a different TLS version. + if c.vers != hs.sessionState.vers { + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return false + } + + // Check that we also support the ciphersuite from the session. + hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, + c.config.cipherSuites(), hs.cipherSuiteOk) + if hs.suite == nil { + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + return false + } + + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + hs.hello.cipherSuite = hs.suite.id + c.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = hs.sessionState.usedOldKey + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := c.processCertsFromClient(Certificate{ + Certificate: hs.sessionState.certificates, + }); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + + return nil +} + +func (hs *serverHandshakeState) doFullHandshake() error { + c := hs.c + + if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled + hs.hello.cipherSuite = hs.suite.id + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + if c.config.ClientAuth == NoClientCert { + // No need to keep a full record of the handshake if client + // certificates won't be used. + hs.finishedHash.discardHandshakeBuffer() + } + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.response = hs.cert.OCSPStaple + hs.finishedHash.Write(certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } + } + + keyAgreement := hs.suite.ka(c.vers) + skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil { + hs.finishedHash.Write(skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } + } + + var certReq *certificateRequestMsg + if c.config.ClientAuth >= RequestClientCert { + // Request a client certificate + certReq = new(certificateRequestMsg) + certReq.certificateTypes = []byte{ + byte(certTypeRSASign), + byte(certTypeECDSASign), + } + if c.vers >= VersionTLS12 { + certReq.hasSignatureAlgorithm = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + hs.finishedHash.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + helloDone := new(serverHelloDoneMsg) + hs.finishedHash.Write(helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } + + if _, err := c.flush(); err != nil { + return err + } + + var pub crypto.PublicKey // public key for client auth, if any + + msg, err := c.readHandshake() + if err != nil { + return err + } + + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + if c.config.ClientAuth >= RequestClientCert { + certMsg, ok := msg.(*certificateMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(Certificate{ + Certificate: certMsg.certificates, + }); err != nil { + return err + } + if len(certMsg.certificates) != 0 { + pub = c.peerCertificates[0].PublicKey + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.finishedHash.Write(ckx.marshal()) + + preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.finishedHash.Write(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher any + var clientHash, serverHash hash.Hash + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + + return nil +} + +func (hs *serverHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") + } + + hs.finishedHash.Write(clientFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + m := new(newSessionTicketMsg) + + createdAt := uint64(c.config.time().Unix()) + if hs.sessionState != nil { + // If this is re-wrapping an old key, then keep + // the original time it was created. + createdAt = hs.sessionState.createdAt + } + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + masterSecret: hs.masterSecret, + certificates: certsFromClient, + } + var err error + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + copy(out, finished.verifyData) + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (c *Conn) processCertsFromClient(certificate Certificate) error { + certificates := certificate.Certificate + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to verify client certificate: " + err.Error()) + } + + c.verifiedChains = chains + } + + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { + supportedVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + supportedVersions = supportedVersionsFromMax(clientHello.vers) + } + + return toClientHelloInfo(&clientHelloInfo{ + CipherSuites: clientHello.cipherSuites, + ServerName: clientHello.serverName, + SupportedCurves: clientHello.supportedCurves, + SupportedPoints: clientHello.supportedPoints, + SignatureSchemes: clientHello.supportedSignatureAlgorithms, + SupportedProtos: clientHello.alpnProtocols, + SupportedVersions: supportedVersions, + Conn: c.conn, + config: toConfig(c.config), + ctx: ctx, + }) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server_tls13.go new file mode 100644 index 00000000..dd8d801e --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server_tls13.go @@ -0,0 +1,896 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "hash" + "io" + "sync/atomic" + "time" +) + +// maxClientPSKIdentities is the number of client PSK identities the server will +// attempt to validate. It will ignore the rest not to let cheap ClientHello +// messages cause too much work in session ticket decryption attempts. +const maxClientPSKIdentities = 5 + +type serverHandshakeStateTLS13 struct { + c *Conn + ctx context.Context + clientHello *clientHelloMsg + hello *serverHelloMsg + alpnNegotiationErr error + encryptedExtensions *encryptedExtensionsMsg + sentDummyCCS bool + usingPSK bool + suite *cipherSuiteTLS13 + cert *Certificate + sigAlg SignatureScheme + earlySecret []byte + sharedKey []byte + handshakeSecret []byte + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 + transcript hash.Hash + clientFinished []byte +} + +func (hs *serverHandshakeStateTLS13) handshake() error { + c := hs.c + + // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. + if err := hs.processClientHello(); err != nil { + return err + } + if err := hs.checkForResumption(); err != nil { + return err + } + if err := hs.pickCertificate(); err != nil { + return err + } + c.buffering = true + if err := hs.sendServerParameters(); err != nil { + return err + } + if err := hs.sendServerCertificate(); err != nil { + return err + } + if err := hs.sendServerFinished(); err != nil { + return err + } + // Note that at this point we could start sending application data without + // waiting for the client's second flight, but the application might not + // expect the lack of replay protection of the ClientHello parameters. + if _, err := c.flush(); err != nil { + return err + } + if err := hs.readClientCertificate(); err != nil { + return err + } + if err := hs.readClientFinished(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *serverHandshakeStateTLS13) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.encryptedExtensions = new(encryptedExtensionsMsg) + + // TLS 1.3 froze the ServerHello.legacy_version field, and uses + // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. + hs.hello.vers = VersionTLS12 + hs.hello.supportedVersion = c.vers + + if len(hs.clientHello.supportedVersions) == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") + } + + // Abort if the client is doing a fallback and landing lower than what we + // support. See RFC 7507, which however does not specify the interaction + // with supported_versions. The only difference is that with + // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] + // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, + // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to + // TLS 1.2, because a TLS 1.3 server would abort here. The situation before + // supported_versions was not better because there was just no way to do a + // TLS 1.4 handshake without risking the server selecting TLS 1.3. + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // Use c.vers instead of max(supported_versions) because an attacker + // could defeat this by adding an arbitrary high version otherwise. + if c.vers < c.config.maxSupportedVersion(roleServer) { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + if len(hs.clientHello.compressionMethods) != 1 || + hs.clientHello.compressionMethods[0] != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: TLS 1.3 client supports illegal compression methods") + } + + hs.hello.random = make([]byte, 32) + if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.compressionMethod = compressionNone + + if hs.suite == nil { + var preferenceList []uint16 + for _, suiteID := range c.config.CipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + preferenceList = append(preferenceList, suiteID) + break + } + } + } + if len(preferenceList) == 0 { + preferenceList = defaultCipherSuitesTLS13 + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = defaultCipherSuitesTLS13NoAES + } + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) + if hs.suite != nil { + break + } + } + } + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + hs.hello.cipherSuite = hs.suite.id + hs.transcript = hs.suite.hash.New() + + // Pick the ECDHE group in server preference order, but give priority to + // groups with a key share, to avoid a HelloRetryRequest round-trip. + var selectedGroup CurveID + var clientKeyShare *keyShare +GroupSelection: + for _, preferredGroup := range c.config.curvePreferences() { + for _, ks := range hs.clientHello.keyShares { + if ks.group == preferredGroup { + selectedGroup = ks.group + clientKeyShare = &ks + break GroupSelection + } + } + if selectedGroup != 0 { + continue + } + for _, group := range hs.clientHello.supportedCurves { + if group == preferredGroup { + selectedGroup = group + break + } + } + } + if selectedGroup == 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no ECDHE curve supported by both client and server") + } + if clientKeyShare == nil { + if err := hs.doHelloRetryRequest(selectedGroup); err != nil { + return err + } + clientKeyShare = &hs.clientHello.keyShares[0] + } + + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} + hs.sharedKey = params.SharedKey(clientKeyShare.data) + if hs.sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid client key share") + } + + c.serverName = hs.clientHello.serverName + + if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil { + c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions) + } + + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + hs.alpnNegotiationErr = err + } + hs.encryptedExtensions.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + + return nil +} + +func (hs *serverHandshakeStateTLS13) checkForResumption() error { + c := hs.c + + if c.config.SessionTicketsDisabled { + return nil + } + + modeOK := false + for _, mode := range hs.clientHello.pskModes { + if mode == pskModeDHE { + modeOK = true + break + } + } + if !modeOK { + return nil + } + + if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid or missing PSK binders") + } + if len(hs.clientHello.pskIdentities) == 0 { + return nil + } + + for i, identity := range hs.clientHello.pskIdentities { + if i >= maxClientPSKIdentities { + break + } + + plaintext, _ := c.decryptTicket(identity.label) + if plaintext == nil { + continue + } + sessionState := new(sessionStateTLS13) + if ok := sessionState.unmarshal(plaintext); !ok { + continue + } + + if hs.clientHello.earlyData { + if sessionState.maxEarlyData == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: client sent unexpected early data") + } + + if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol && + c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 && + c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) { + hs.encryptedExtensions.earlyData = true + c.used0RTT = true + } + } + + createdAt := time.Unix(int64(sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + continue + } + + // We don't check the obfuscated ticket age because it's affected by + // clock skew and it's only a freshness signal useful for shrinking the + // window for replay attacks, which don't affect us as we don't do 0-RTT. + + pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) + if pskSuite == nil || pskSuite.hash != hs.suite.hash { + continue + } + + // PSK connections don't re-establish client certificates, but carry + // them over in the session ticket. Ensure the presence of client certs + // in the ticket is consistent with the configured requirements. + sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + continue + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + continue + } + + psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", + nil, hs.suite.hash.Size()) + hs.earlySecret = hs.suite.extract(psk, nil) + binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) + // Clone the transcript in case a HelloRetryRequest was recorded. + transcript := cloneHash(hs.transcript, hs.suite.hash) + if transcript == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: failed to clone hash") + } + transcript.Write(hs.clientHello.marshalWithoutBinders()) + pskBinder := hs.suite.finishedHash(binderKey, transcript) + if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid PSK binder") + } + + c.didResume = true + if err := c.processCertsFromClient(sessionState.certificate); err != nil { + return err + } + + h := cloneHash(hs.transcript, hs.suite.hash) + h.Write(hs.clientHello.marshal()) + if hs.encryptedExtensions.earlyData { + clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) + c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + + hs.hello.selectedIdentityPresent = true + hs.hello.selectedIdentity = uint16(i) + hs.usingPSK = true + return nil + } + + return nil +} + +// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// interfaces implemented by standard library hashes to clone the state of in +// to a new instance of h. It returns nil if the operation fails. +func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { + // Recreate the interface to avoid importing encoding. + type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) + UnmarshalBinary(data []byte) error + } + marshaler, ok := in.(binaryMarshaler) + if !ok { + return nil + } + state, err := marshaler.MarshalBinary() + if err != nil { + return nil + } + out := h.New() + unmarshaler, ok := out.(binaryMarshaler) + if !ok { + return nil + } + if err := unmarshaler.UnmarshalBinary(state); err != nil { + return nil + } + return out +} + +func (hs *serverHandshakeStateTLS13) pickCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. + if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { + return c.sendAlert(alertMissingExtension) + } + + certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) + if err != nil { + // getCertificate returned a certificate that is unsupported or + // incompatible with the client's signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + hs.cert = certificate + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. See RFC 8446, Section 4.4.1. + hs.transcript.Write(hs.clientHello.marshal()) + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + + helloRetryRequest := &serverHelloMsg{ + vers: hs.hello.vers, + random: helloRetryRequestRandom, + sessionId: hs.hello.sessionId, + cipherSuite: hs.hello.cipherSuite, + compressionMethod: hs.hello.compressionMethod, + supportedVersion: hs.hello.supportedVersion, + selectedGroup: selectedGroup, + } + + hs.transcript.Write(helloRetryRequest.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientHello, msg) + } + + if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client sent invalid key share in second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client indicated early data in second ClientHello") + } + + if illegalClientHelloChange(clientHello, hs.clientHello) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client illegally modified second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client offered 0-RTT data in second ClientHello") + } + + hs.clientHello = clientHello + return nil +} + +// illegalClientHelloChange reports whether the two ClientHello messages are +// different, with the exception of the changes allowed before and after a +// HelloRetryRequest. See RFC 8446, Section 4.1.2. +func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { + if len(ch.supportedVersions) != len(ch1.supportedVersions) || + len(ch.cipherSuites) != len(ch1.cipherSuites) || + len(ch.supportedCurves) != len(ch1.supportedCurves) || + len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || + len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || + len(ch.alpnProtocols) != len(ch1.alpnProtocols) { + return true + } + for i := range ch.supportedVersions { + if ch.supportedVersions[i] != ch1.supportedVersions[i] { + return true + } + } + for i := range ch.cipherSuites { + if ch.cipherSuites[i] != ch1.cipherSuites[i] { + return true + } + } + for i := range ch.supportedCurves { + if ch.supportedCurves[i] != ch1.supportedCurves[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithms { + if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithmsCert { + if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { + return true + } + } + for i := range ch.alpnProtocols { + if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { + return true + } + } + return ch.vers != ch1.vers || + !bytes.Equal(ch.random, ch1.random) || + !bytes.Equal(ch.sessionId, ch1.sessionId) || + !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || + ch.serverName != ch1.serverName || + ch.ocspStapling != ch1.ocspStapling || + !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || + ch.ticketSupported != ch1.ticketSupported || + !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || + ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || + !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || + ch.scts != ch1.scts || + !bytes.Equal(ch.cookie, ch1.cookie) || + !bytes.Equal(ch.pskModes, ch1.pskModes) +} + +func (hs *serverHandshakeStateTLS13) sendServerParameters() error { + c := hs.c + + hs.transcript.Write(hs.clientHello.marshal()) + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + earlySecret := hs.earlySecret + if earlySecret == nil { + earlySecret = hs.suite.extract(nil, nil) + } + hs.handshakeSecret = hs.suite.extract(hs.sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.in.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if hs.alpnNegotiationErr != nil { + c.sendAlert(alertNoApplicationProtocol) + return hs.alpnNegotiationErr + } + if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil { + hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) + } + + hs.transcript.Write(hs.encryptedExtensions.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) requestClientCert() bool { + return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK +} + +func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + if hs.requestClientCert() { + // Request a client certificate + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + + hs.transcript.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *hs.cert + certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + certVerifyMsg.signatureAlgorithm = hs.sigAlg + + sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + public := hs.cert.PrivateKey.(crypto.Signer).Public() + if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && + rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS + c.sendAlert(alertHandshakeFailure) + } else { + c.sendAlert(alertInternalError) + } + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) sendServerFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + // Derive secrets that take context through the server Finished. + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + // If we did not request client certificates, at this point we can + // precompute the client finished and roll the transcript forward to send + // session tickets in our first flight. + if !hs.requestClientCert() { + if err := hs.sendSessionTickets(); err != nil { + return err + } + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { + if hs.c.config.SessionTicketsDisabled { + return false + } + + // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. + for _, pskMode := range hs.clientHello.pskModes { + if pskMode == pskModeDHE { + return true + } + } + return false +} + +func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { + c := hs.c + + hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + finishedMsg := &finishedMsg{ + verifyData: hs.clientFinished, + } + hs.transcript.Write(finishedMsg.marshal()) + + if !hs.shouldSendSessionTickets() { + return nil + } + + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + + // Don't send session tickets when the alternative record layer is set. + // Instead, save the resumption secret on the Conn. + // Session tickets can then be generated by calling Conn.GetSessionTicket(). + if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil { + return nil + } + + m, err := hs.c.getSessionTicketMsg(nil) + if err != nil { + return err + } + + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientCertificate() error { + c := hs.c + + if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + // If we requested a client certificate, then the client must send a + // certificate message. If it's empty, no CertificateVerify is sent. + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.transcript.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(certMsg.certificate); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if len(certMsg.certificate.Certificate) != 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + } + + // If we waited until the client certificates to send session tickets, we + // are ready to do it now. + if err := hs.sendSessionTickets(); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + if !hmac.Equal(hs.clientFinished, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid client finished hash") + } + + c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.in.setTrafficSecret(hs.suite, hs.trafficSecret) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/key_agreement.go b/vendor/github.com/marten-seemann/qtls-go1-18/key_agreement.go new file mode 100644 index 00000000..453a8dcf --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/key_agreement.go @@ -0,0 +1,357 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "errors" + "fmt" + "io" +) + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct{} + +func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext := ckx.ciphertext[2:] + + priv, ok := cert.PrivateKey.(crypto.Decrypter) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") + } + // Perform constant time RSA PKCS #1 v1.5 decryption + preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) + if err != nil { + return nil, err + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + return preMasterSecret, nil +} + +func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + preMasterSecret[0] = byte(clientHello.vers >> 8) + preMasterSecret[1] = byte(clientHello.vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") + } + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// using the given hash function (for >= TLS 1.2) or using a default based on +// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't +// do pre-hashing, it returns the concatenation of the slices. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { + if sigType == signatureEd25519 { + var signed []byte + for _, slice := range slices { + signed = append(signed, slice...) + } + return signed + } + if version >= VersionTLS12 { + h := hashFunc.New() + for _, slice := range slices { + h.Write(slice) + } + digest := h.Sum(nil) + return digest + } + if sigType == signatureECDSA { + return sha1Hash(slices) + } + return md5SHA1Hash(slices) +} + +// ecdheKeyAgreement implements a TLS key agreement where the server +// generates an ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// be ECDSA, Ed25519 or RSA. +type ecdheKeyAgreement struct { + version uint16 + isRSA bool + params ecdheParameters + + // ckx and preMasterSecret are generated in processServerKeyExchange + // and returned in generateClientKeyExchange. + ckx *clientKeyExchangeMsg + preMasterSecret []byte +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + var curveID CurveID + for _, c := range clientHello.supportedCurves { + if config.supportsCurve(c) { + curveID = c + break + } + } + + if curveID == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, err + } + ka.params = params + + // See RFC 4492, Section 5.4. + ecdhePublic := params.PublicKey() + serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHEParams[0] = 3 // named curve + serverECDHEParams[1] = byte(curveID >> 8) + serverECDHEParams[2] = byte(curveID) + serverECDHEParams[3] = byte(len(ecdhePublic)) + copy(serverECDHEParams[4:], ecdhePublic) + + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) + } + + var signatureAlgorithm SignatureScheme + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) + if err != nil { + return nil, err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return nil, err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) + if err != nil { + return nil, err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) + + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := priv.Sign(config.rand(), signed, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } + + skx := new(serverKeyExchangeMsg) + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) + copy(skx.key, serverECDHEParams) + k := skx.key[len(serverECDHEParams):] + if ka.version >= VersionTLS12 { + k[0] = byte(signatureAlgorithm >> 8) + k[1] = byte(signatureAlgorithm) + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + + return skx, nil +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + + preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) + if preMasterSecret == nil { + return nil, errClientKeyExchange + } + + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + serverECDHEParams := skx.key[:4+publicLen] + publicKey := serverECDHEParams[4:] + + sig := skx.key[4+publicLen:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return errors.New("tls: server selected unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return err + } + ka.params = params + + ka.preMasterSecret = params.SharedKey(publicKey) + if ka.preMasterSecret == nil { + return errServerKeyExchange + } + + ourPublicKey := params.PublicKey() + ka.ckx = new(clientKeyExchangeMsg) + ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) + ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) + copy(ka.ckx.ciphertext[1:], ourPublicKey) + + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) + if err != nil { + return err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + return nil +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.ckx == nil { + return nil, nil, errors.New("tls: missing ServerKeyExchange message") + } + + return ka.preMasterSecret, ka.ckx, nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/key_schedule.go b/vendor/github.com/marten-seemann/qtls-go1-18/key_schedule.go new file mode 100644 index 00000000..da13904a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/key_schedule.go @@ -0,0 +1,199 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto/elliptic" + "crypto/hmac" + "errors" + "hash" + "io" + "math/big" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +// This file contains the functions necessary to compute the TLS 1.3 key +// schedule. See RFC 8446, Section 7. + +const ( + resumptionBinderLabel = "res binder" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + exporterLabel = "exp master" + resumptionLabel = "res master" + trafficUpdateLabel = "traffic upd" +) + +// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("tls: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { + if transcript == nil { + transcript = c.hash.New() + } + return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) +} + +// extract implements HKDF-Extract with the cipher suite hash. +func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, c.hash.Size()) + } + return hkdf.Extract(c.hash.New, newSecret, currentSecret) +} + +// nextTrafficSecret generates the next traffic secret, given the current one, +// according to RFC 8446, Section 7.2. +func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { + return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) +} + +// trafficKey generates traffic keys according to RFC 8446, Section 7.3. +func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { + key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) + iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + return +} + +// finishedHash generates the Finished verify_data or PskBinderEntry according +// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey +// selection. +func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { + finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) + verifyData := hmac.New(c.hash.New, finishedKey) + verifyData.Write(transcript.Sum(nil)) + return verifyData.Sum(nil) +} + +// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to +// RFC 8446, Section 7.5. +func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) + return func(label string, context []byte, length int) ([]byte, error) { + secret := c.deriveSecret(expMasterSecret, label, nil) + h := c.hash.New() + h.Write(context) + return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + } +} + +// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// according to RFC 8446, Section 4.2.8.2. +type ecdheParameters interface { + CurveID() CurveID + PublicKey() []byte + SharedKey(peerPublicKey []byte) []byte +} + +func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { + if curveID == X25519 { + privateKey := make([]byte, curve25519.ScalarSize) + if _, err := io.ReadFull(rand, privateKey); err != nil { + return nil, err + } + publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) + if err != nil { + return nil, err + } + return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, errors.New("tls: internal error: unsupported curve") + } + + p := &nistParameters{curveID: curveID} + var err error + p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) + if err != nil { + return nil, err + } + return p, nil +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } +} + +type nistParameters struct { + privateKey []byte + x, y *big.Int // public key + curveID CurveID +} + +func (p *nistParameters) CurveID() CurveID { + return p.curveID +} + +func (p *nistParameters) PublicKey() []byte { + curve, _ := curveForCurveID(p.curveID) + return elliptic.Marshal(curve, p.x, p.y) +} + +func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { + curve, _ := curveForCurveID(p.curveID) + // Unmarshal also checks whether the given point is on the curve. + x, y := elliptic.Unmarshal(curve, peerPublicKey) + if x == nil { + return nil + } + + xShared, _ := curve.ScalarMult(x, y, p.privateKey) + sharedKey := make([]byte, (curve.Params().BitSize+7)/8) + return xShared.FillBytes(sharedKey) +} + +type x25519Parameters struct { + privateKey []byte + publicKey []byte +} + +func (p *x25519Parameters) CurveID() CurveID { + return X25519 +} + +func (p *x25519Parameters) PublicKey() []byte { + return p.publicKey[:] +} + +func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { + sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) + if err != nil { + return nil + } + return sharedKey +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/prf.go b/vendor/github.com/marten-seemann/qtls-go1-18/prf.go new file mode 100644 index 00000000..9eb0221a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/prf.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +const ( + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") + +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + switch version { + case VersionTLS10, VersionTLS11: + return prf10, crypto.Hash(0) + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + default: + panic("unknown version") + } +} + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See RFC 5246, Section 8.1. +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { + seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) + return masterSecret +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, Section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) + seed = append(seed, serverRandom...) + seed = append(seed, clientRandom...) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version >= VersionTLS12 { + buffer = []byte{} + } + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// hashForClientCertificate returns the handshake messages so far, pre-hashed if +// necessary, suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { + if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { + panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") + } + + if sigType == signatureEd25519 { + return h.buffer + } + + if h.version >= VersionTLS12 { + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil) + } + + if sigType == signatureECDSA { + return h.server.Sum(nil) + } + + return h.Sum() +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} + +// noExportedKeyingMaterial is used as a value of +// ConnectionState.ekm when renegotiation is enabled and thus +// we wish to fail all key-material export requests. +func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") +} + +// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. +func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { + return func(label string, context []byte, length int) ([]byte, error) { + switch label { + case "client finished", "server finished", "master secret", "key expansion": + // These values are reserved and may not be used. + return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) + } + + seedLen := len(serverRandom) + len(clientRandom) + if context != nil { + seedLen += 2 + len(context) + } + seed := make([]byte, 0, seedLen) + + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + if context != nil { + if len(context) >= 1<<16 { + return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") + } + seed = append(seed, byte(len(context)>>8), byte(len(context))) + seed = append(seed, context...) + } + + keyMaterial := make([]byte, length) + prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) + return keyMaterial, nil + } +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/ticket.go b/vendor/github.com/marten-seemann/qtls-go1-18/ticket.go new file mode 100644 index 00000000..81e8a52e --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/ticket.go @@ -0,0 +1,274 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/binary" + "errors" + "io" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + createdAt uint64 + masterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^24-1> } Certificate; + certificates [][]byte // Certificate certificate_list<0..2^24-1>; + + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + usedOldKey bool +} + +func (m *sessionState) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(m.vers) + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.masterSecret) + }) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for _, cert := range m.certificates { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + } + }) + return b.BytesOrPanic() +} + +func (m *sessionState) unmarshal(data []byte) bool { + *m = sessionState{usedOldKey: m.usedOldKey} + s := cryptobyte.String(data) + if ok := s.ReadUint16(&m.vers) && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint16LengthPrefixed(&s, &m.masterSecret) && + len(m.masterSecret) != 0; !ok { + return false + } + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + if !readUint24LengthPrefixed(&certList, &cert) { + return false + } + m.certificates = append(m.certificates, cert) + } + return s.Empty() +} + +// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first +// version (revision = 0) doesn't carry any of the information needed for 0-RTT +// validation and the nonce is always empty. +// version (revision = 1) carries the max_early_data_size sent in the ticket. +// version (revision = 2) carries the ALPN sent in the ticket. +type sessionStateTLS13 struct { + // uint8 version = 0x0304; + // uint8 revision = 2; + cipherSuite uint16 + createdAt uint64 + resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; + certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; + maxEarlyData uint32 + alpn string + + appData []byte +} + +func (m *sessionStateTLS13) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(VersionTLS13) + b.AddUint8(2) // revision + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.resumptionSecret) + }) + marshalCertificate(&b, m.certificate) + b.AddUint32(m.maxEarlyData) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpn)) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.appData) + }) + return b.BytesOrPanic() +} + +func (m *sessionStateTLS13) unmarshal(data []byte) bool { + *m = sessionStateTLS13{} + s := cryptobyte.String(data) + var version uint16 + var revision uint8 + var alpn []byte + ret := s.ReadUint16(&version) && + version == VersionTLS13 && + s.ReadUint8(&revision) && + revision == 2 && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint8LengthPrefixed(&s, &m.resumptionSecret) && + len(m.resumptionSecret) != 0 && + unmarshalCertificate(&s, &m.certificate) && + s.ReadUint32(&m.maxEarlyData) && + readUint8LengthPrefixed(&s, &alpn) && + readUint16LengthPrefixed(&s, &m.appData) && + s.Empty() + m.alpn = string(alpn) + return ret +} + +func (c *Conn) encryptTicket(state []byte) ([]byte, error) { + if len(c.ticketKeys) == 0 { + return nil, errors.New("tls: internal error: session ticket keys unavailable") + } + + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + key := c.ticketKeys[0] + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { + if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { + return nil, false + } + + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] + + keyIndex := -1 + for i, candidateKey := range c.ticketKeys { + if bytes.Equal(keyName, candidateKey.keyName[:]) { + keyIndex = i + break + } + } + if keyIndex == -1 { + return nil, false + } + key := &c.ticketKeys[keyIndex] + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, false + } + plaintext = make([]byte, len(ciphertext)) + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + return plaintext, keyIndex > 0 +} + +func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) { + m := new(newSessionTicketMsgTLS13) + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionStateTLS13{ + cipherSuite: c.cipherSuite, + createdAt: uint64(c.config.time().Unix()), + resumptionSecret: c.resumptionSecret, + certificate: Certificate{ + Certificate: certsFromClient, + OCSPStaple: c.ocspResponse, + SignedCertificateTimestamps: c.scts, + }, + appData: appData, + alpn: c.clientProtocol, + } + if c.extraConfig != nil { + state.maxEarlyData = c.extraConfig.MaxEarlyData + } + var err error + m.label, err = c.encryptTicket(state.marshal()) + if err != nil { + return nil, err + } + m.lifetime = uint32(maxSessionTicketLifetime / time.Second) + + // ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1 + // The value is not stored anywhere; we never need to check the ticket age + // because 0-RTT is not supported. + ageAdd := make([]byte, 4) + _, err = c.config.rand().Read(ageAdd) + if err != nil { + return nil, err + } + m.ageAdd = binary.LittleEndian.Uint32(ageAdd) + + // ticket_nonce, which must be unique per connection, is always left at + // zero because we only ever send one ticket per connection. + + if c.extraConfig != nil { + m.maxEarlyData = c.extraConfig.MaxEarlyData + } + return m, nil +} + +// GetSessionTicket generates a new session ticket. +// It should only be called after the handshake completes. +// It can only be used for servers, and only if the alternative record layer is set. +// The ticket may be nil if config.SessionTicketsDisabled is set, +// or if the client isn't able to receive session tickets. +func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { + if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.") + } + if c.config.SessionTicketsDisabled { + return nil, nil + } + + m, err := c.getSessionTicketMsg(appData) + if err != nil { + return nil, err + } + return m.marshal(), nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/tls.go b/vendor/github.com/marten-seemann/qtls-go1-18/tls.go new file mode 100644 index 00000000..42207c23 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/tls.go @@ -0,0 +1,362 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// package qtls partially implements TLS 1.2, as specified in RFC 5246, +// and TLS 1.3, as specified in RFC 8446. +package qtls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "os" + "strings" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + } + c.handshakeFn = c.serverHandshake + return c +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + isClient: true, + } + c.handshakeFn = c.clientHandshake + return c +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type listener struct { + net.Listener + config *Config + extraConfig *ExtraConfig +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection is of type *Conn. +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return Server(c, l.config, l.extraConfig), nil +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener { + l := new(listener) + l.Listener = inner + l.config = config + l.extraConfig = extraConfig + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) { + if config == nil || len(config.Certificates) == 0 && + config.GetCertificate == nil && config.GetConfigForClient == nil { + return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config, extraConfig), nil +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +// +// DialWithDialer uses context.Background internally; to specify the context, +// use Dialer.DialContext with NetDialer set to the desired dialer. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return dial(context.Background(), dialer, network, addr, config, extraConfig) +} + +func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + if netDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) + defer cancel() + } + + if !netDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) + defer cancel() + } + + rawConn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = defaultConfig() + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config, extraConfig) + if err := conn.HandshakeContext(ctx); err != nil { + rawConn.Close() + return nil, err + } + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig) +} + +// Dialer dials TLS connections given a configuration and a Dialer for the +// underlying connection. +type Dialer struct { + // NetDialer is the optional dialer to use for the TLS connections' + // underlying TCP connections. + // A nil NetDialer is equivalent to the net.Dialer zero value. + NetDialer *net.Dialer + + // Config is the TLS configuration to use for new connections. + // A nil configuration is equivalent to the zero + // configuration; see the documentation of Config for the + // defaults. + Config *Config + + ExtraConfig *ExtraConfig +} + +// Dial connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The returned Conn, if any, will always be of type *Conn. +// +// Dial uses context.Background internally; to specify the context, +// use DialContext. +func (d *Dialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func (d *Dialer) netDialer() *net.Dialer { + if d.NetDialer != nil { + return d.NetDialer + } + return new(net.Dialer) +} + +// DialContext connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig) + if err != nil { + // Don't return c (a typed nil) in an interface. + return nil, err + } + return c, nil +} + +// LoadX509KeyPair reads and parses a public/private key pair from a pair +// of files. The files must contain PEM encoded data. The certificate file +// may contain intermediate certificates following the leaf certificate to +// form a certificate chain. On successful return, Certificate.Leaf will +// be nil because the parsed form of the certificate is not retained. +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := os.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := os.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. On successful return, Certificate.Leaf will be nil because +// the parsed form of the certificate is not retained. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + fail := func(err error) (Certificate, error) { return Certificate{}, err } + + var cert Certificate + var skippedBlockTypes []string + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(cert.Certificate) == 0 { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in certificate input")) + } + if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { + return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) + } + return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + + skippedBlockTypes = skippedBlockTypes[:0] + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in key input")) + } + if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { + return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) + } + return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fail(err) + } + + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return fail(err) + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.N.Cmp(priv.N) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case ed25519.PublicKey: + priv, ok := cert.PrivateKey.(ed25519.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { + return fail(errors.New("tls: private key does not match public key")) + } + default: + return fail(errors.New("tls: unknown public key algorithm")) + } + + return cert, nil +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse private key") +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go b/vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go new file mode 100644 index 00000000..55fa01b3 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go @@ -0,0 +1,96 @@ +package qtls + +import ( + "crypto/tls" + "reflect" + "unsafe" +) + +func init() { + if !structsEqual(&tls.ConnectionState{}, &connectionState{}) { + panic("qtls.ConnectionState doesn't match") + } + if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) { + panic("qtls.ClientSessionState doesn't match") + } + if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) { + panic("qtls.CertificateRequestInfo doesn't match") + } + if !structsEqual(&tls.Config{}, &config{}) { + panic("qtls.Config doesn't match") + } + if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { + panic("qtls.ClientHelloInfo doesn't match") + } +} + +func toConnectionState(c connectionState) ConnectionState { + return *(*ConnectionState)(unsafe.Pointer(&c)) +} + +func toClientSessionState(s *clientSessionState) *ClientSessionState { + return (*ClientSessionState)(unsafe.Pointer(s)) +} + +func fromClientSessionState(s *ClientSessionState) *clientSessionState { + return (*clientSessionState)(unsafe.Pointer(s)) +} + +func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo { + return (*CertificateRequestInfo)(unsafe.Pointer(i)) +} + +func toConfig(c *config) *Config { + return (*Config)(unsafe.Pointer(c)) +} + +func fromConfig(c *Config) *config { + return (*config)(unsafe.Pointer(c)) +} + +func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo { + return (*ClientHelloInfo)(unsafe.Pointer(chi)) +} + +func structsEqual(a, b interface{}) bool { + return compare(reflect.ValueOf(a), reflect.ValueOf(b)) +} + +func compare(a, b reflect.Value) bool { + sa := a.Elem() + sb := b.Elem() + if sa.NumField() != sb.NumField() { + return false + } + for i := 0; i < sa.NumField(); i++ { + fa := sa.Type().Field(i) + fb := sb.Type().Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + if fa.Type.Kind() != fb.Type.Kind() { + return false + } + if fa.Type.Kind() == reflect.Slice { + if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) { + return false + } + continue + } + return false + } + } + return true +} + +func compareStruct(a, b reflect.Type) bool { + if a.NumField() != b.NumField() { + return false + } + for i := 0; i < a.NumField(); i++ { + fa := a.Field(i) + fb := b.Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + return false + } + } + return true +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/LICENSE b/vendor/github.com/marten-seemann/qtls-go1-19/LICENSE new file mode 100644 index 00000000..6a66aea5 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/README.md b/vendor/github.com/marten-seemann/qtls-go1-19/README.md new file mode 100644 index 00000000..3e902212 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/README.md @@ -0,0 +1,6 @@ +# qtls + +[![Go Reference](https://pkg.go.dev/badge/github.com/marten-seemann/qtls-go1-17.svg)](https://pkg.go.dev/github.com/marten-seemann/qtls-go1-17) +[![.github/workflows/go-test.yml](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml/badge.svg)](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml) + +This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/alert.go b/vendor/github.com/marten-seemann/qtls-go1-19/alert.go new file mode 100644 index 00000000..3feac79b --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/alert.go @@ -0,0 +1,102 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import "strconv" + +type alert uint8 + +// Alert is a TLS alert +type Alert = alert + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertExportRestriction alert = 60 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertInappropriateFallback alert = 86 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 + alertMissingExtension alert = 109 + alertUnsupportedExtension alert = 110 + alertCertificateUnobtainable alert = 111 + alertUnrecognizedName alert = 112 + alertBadCertificateStatusResponse alert = 113 + alertBadCertificateHashValue alert = 114 + alertUnknownPSKIdentity alert = 115 + alertCertificateRequired alert = 116 + alertNoApplicationProtocol alert = 120 +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertExportRestriction: "export restriction", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertInappropriateFallback: "inappropriate fallback", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", + alertMissingExtension: "missing extension", + alertUnsupportedExtension: "unsupported extension", + alertCertificateUnobtainable: "certificate unobtainable", + alertUnrecognizedName: "unrecognized name", + alertBadCertificateStatusResponse: "bad certificate status response", + alertBadCertificateHashValue: "bad certificate hash value", + alertUnknownPSKIdentity: "unknown PSK identity", + alertCertificateRequired: "certificate required", + alertNoApplicationProtocol: "no application protocol", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return "tls: " + s + } + return "tls: alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/auth.go b/vendor/github.com/marten-seemann/qtls-go1-19/auth.go new file mode 100644 index 00000000..effc9ace --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/auth.go @@ -0,0 +1,293 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + "hash" + "io" +) + +// verifyHandshakeSignature verifies a signature against pre-hashed +// (if required) handshake contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) + } + if !ecdsa.VerifyASN1(pubKey, signed, sig) { + return errors.New("ECDSA verification failure") + } + case signatureEd25519: + pubKey, ok := pubkey.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) + } + if !ed25519.Verify(pubKey, signed, sig) { + return errors.New("Ed25519 verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { + return err + } + case signatureRSAPSS: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { + return err + } + default: + return errors.New("internal error: unknown signature type") + } + return nil +} + +const ( + serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" + clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" +) + +var signaturePadding = []byte{ + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, + 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, +} + +// signedMessage returns the pre-hashed (if necessary) message to be signed by +// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. +func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { + if sigHash == directSigning { + b := &bytes.Buffer{} + b.Write(signaturePadding) + io.WriteString(b, context) + b.Write(transcript.Sum(nil)) + return b.Bytes() + } + h := sigHash.New() + h.Write(signaturePadding) + io.WriteString(h, context) + h.Write(transcript.Sum(nil)) + return h.Sum(nil) +} + +// typeAndHashFromSignatureScheme returns the corresponding signature type and +// crypto.Hash for a given TLS SignatureScheme. +func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { + switch signatureAlgorithm { + case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: + sigType = signaturePKCS1v15 + case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: + sigType = signatureRSAPSS + case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: + sigType = signatureECDSA + case Ed25519: + sigType = signatureEd25519 + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + switch signatureAlgorithm { + case PKCS1WithSHA1, ECDSAWithSHA1: + hash = crypto.SHA1 + case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: + hash = crypto.SHA256 + case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: + hash = crypto.SHA384 + case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: + hash = crypto.SHA512 + case Ed25519: + hash = directSigning + default: + return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) + } + return sigType, hash, nil +} + +// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for +// a given public key used with TLS 1.0 and 1.1, before the introduction of +// signature algorithm negotiation. +func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { + switch pub.(type) { + case *rsa.PublicKey: + return signaturePKCS1v15, crypto.MD5SHA1, nil + case *ecdsa.PublicKey: + return signatureECDSA, crypto.SHA1, nil + case ed25519.PublicKey: + // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, + // but it requires holding on to a handshake transcript to do a + // full signature, and not even OpenSSL bothers with the + // complexity, so we can't even test it properly. + return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") + default: + return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) + } +} + +var rsaSignatureSchemes = []struct { + scheme SignatureScheme + minModulusBytes int + maxVersion uint16 +}{ + // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires + // emLen >= hLen + sLen + 2 + {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, + // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires + // emLen >= len(prefix) + hLen + 11 + // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. + {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, + {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, + {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, + {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, +} + +// signatureSchemesForCertificate returns the list of supported SignatureSchemes +// for a given certificate, based on the public key and the protocol version, +// and optionally filtered by its explicit SupportedSignatureAlgorithms. +// +// This function must be kept in sync with supportedSignatureAlgorithms. +// FIPS filtering is applied in the caller, selectSignatureScheme. +func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil + } + + var sigAlgs []SignatureScheme + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + if version != VersionTLS13 { + // In TLS 1.2 and earlier, ECDSA algorithms are not + // constrained to a single curve. + sigAlgs = []SignatureScheme{ + ECDSAWithP256AndSHA256, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + ECDSAWithSHA1, + } + break + } + switch pub.Curve { + case elliptic.P256(): + sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return nil + } + case *rsa.PublicKey: + size := pub.Size() + sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) + for _, candidate := range rsaSignatureSchemes { + if size >= candidate.minModulusBytes && version <= candidate.maxVersion { + sigAlgs = append(sigAlgs, candidate.scheme) + } + } + case ed25519.PublicKey: + sigAlgs = []SignatureScheme{Ed25519} + default: + return nil + } + + if cert.SupportedSignatureAlgorithms != nil { + var filteredSigAlgs []SignatureScheme + for _, sigAlg := range sigAlgs { + if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { + filteredSigAlgs = append(filteredSigAlgs, sigAlg) + } + } + return filteredSigAlgs + } + return sigAlgs +} + +// selectSignatureScheme picks a SignatureScheme from the peer's preference list +// that works with the selected certificate. It's only called for protocol +// versions that support signature algorithms, so TLS 1.2 and 1.3. +func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { + supportedAlgs := signatureSchemesForCertificate(vers, c) + if len(supportedAlgs) == 0 { + return 0, unsupportedCertificateError(c) + } + if len(peerAlgs) == 0 && vers == VersionTLS12 { + // For TLS 1.2, if the client didn't send signature_algorithms then we + // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. + peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} + } + // Pick signature scheme in the peer's preference order, as our + // preference order is not configurable. + for _, preferredAlg := range peerAlgs { + if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) { + continue + } + if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { + return preferredAlg, nil + } + } + return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") +} + +// unsupportedCertificateError returns a helpful error for certificates with +// an unsupported private key. +func unsupportedCertificateError(cert *Certificate) error { + switch cert.PrivateKey.(type) { + case rsa.PrivateKey, ecdsa.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", + cert.PrivateKey, cert.PrivateKey) + case *ed25519.PrivateKey: + return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") + } + + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", + cert.PrivateKey) + } + + switch pub := signer.Public().(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + case elliptic.P384(): + case elliptic.P521(): + default: + return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) + } + case *rsa.PublicKey: + return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") + case ed25519.PublicKey: + default: + return fmt.Errorf("tls: unsupported certificate key (%T)", pub) + } + + if cert.SupportedSignatureAlgorithms != nil { + return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") + } + + return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/cipher_suites.go b/vendor/github.com/marten-seemann/qtls-go1-19/cipher_suites.go new file mode 100644 index 00000000..56dd4543 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/cipher_suites.go @@ -0,0 +1,693 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "fmt" + "hash" + + "golang.org/x/crypto/chacha20poly1305" +) + +// CipherSuite is a TLS cipher suite. Note that most functions in this package +// accept and expose cipher suite IDs instead of this type. +type CipherSuite struct { + ID uint16 + Name string + + // Supported versions is the list of TLS protocol versions that can + // negotiate this cipher suite. + SupportedVersions []uint16 + + // Insecure is true if the cipher suite has known security issues + // due to its primitives, design, or implementation. + Insecure bool +} + +var ( + supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} + supportedOnlyTLS12 = []uint16{VersionTLS12} + supportedOnlyTLS13 = []uint16{VersionTLS13} +) + +// CipherSuites returns a list of cipher suites currently implemented by this +// package, excluding those with security issues, which are returned by +// InsecureCipherSuites. +// +// The list is sorted by ID. Note that the default cipher suites selected by +// this package might depend on logic that can't be captured by a static list, +// and might not match those returned by this function. +func CipherSuites() []*CipherSuite { + return []*CipherSuite{ + {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + + {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, + {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, + {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, + + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, + } +} + +// InsecureCipherSuites returns a list of cipher suites currently implemented by +// this package and which have security issues. +// +// Most applications should not use the cipher suites in this list, and should +// only use those returned by CipherSuites. +func InsecureCipherSuites() []*CipherSuite { + // This list includes RC4, CBC_SHA256, and 3DES cipher suites. See + // cipherSuitesPreferenceOrder for details. + return []*CipherSuite{ + {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, + } +} + +// CipherSuiteName returns the standard name for the passed cipher suite ID +// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation +// of the ID value if the cipher suite is not implemented by this package. +func CipherSuiteName(id uint16) string { + for _, c := range CipherSuites() { + if c.ID == id { + return c.Name + } + } + for _, c := range InsecureCipherSuites() { + if c.ID == id { + return c.Name + } + } + return fmt.Sprintf("0x%04X", id) +} + +const ( + // suiteECDHE indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECSign indicates that the cipher suite involves an ECDSA or + // EdDSA signature and therefore may only be selected when the server's + // certificate is ECDSA or EdDSA. If this is not set then the cipher suite + // is RSA based. + suiteECSign + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 +) + +// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange +// mechanism, as well as the cipher+MAC pair or the AEAD. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) any + mac func(key []byte) hash.Hash + aead func(key, fixedNonce []byte) aead +} + +var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil}, +} + +// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which +// is also in supportedIDs and passes the ok filter. +func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { + for _, id := range ids { + candidate := cipherSuiteByID(id) + if candidate == nil || !ok(candidate) { + continue + } + + for _, suppID := range supportedIDs { + if id == suppID { + return candidate + } + } + } + return nil +} + +// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash +// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. +type cipherSuiteTLS13 struct { + id uint16 + keyLen int + aead func(key, fixedNonce []byte) aead + hash crypto.Hash +} + +type CipherSuiteTLS13 struct { + ID uint16 + KeyLen int + Hash crypto.Hash + AEAD func(key, fixedNonce []byte) cipher.AEAD +} + +func (c *CipherSuiteTLS13) IVLen() int { + return aeadNonceLength +} + +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. + {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, + {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, + {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, +} + +// cipherSuitesPreferenceOrder is the order in which we'll select (on the +// server) or advertise (on the client) TLS 1.0–1.2 cipher suites. +// +// Cipher suites are filtered but not reordered based on the application and +// peer's preferences, meaning we'll never select a suite lower in this list if +// any higher one is available. This makes it more defensible to keep weaker +// cipher suites enabled, especially on the server side where we get the last +// word, since there are no known downgrade attacks on cipher suites selection. +// +// The list is sorted by applying the following priority rules, stopping at the +// first (most important) applicable one: +// +// - Anything else comes before RC4 +// +// RC4 has practically exploitable biases. See https://www.rc4nomore.com. +// +// - Anything else comes before CBC_SHA256 +// +// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 +// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. +// +// - Anything else comes before 3DES +// +// 3DES has 64-bit blocks, which makes it fundamentally susceptible to +// birthday attacks. See https://sweet32.info. +// +// - ECDHE comes before anything else +// +// Once we got the broken stuff out of the way, the most important +// property a cipher suite can have is forward secrecy. We don't +// implement FFDHE, so that means ECDHE. +// +// - AEADs come before CBC ciphers +// +// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites +// are fundamentally fragile, and suffered from an endless sequence of +// padding oracle attacks. See https://eprint.iacr.org/2015/1129, +// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and +// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. +// +// - AES comes before ChaCha20 +// +// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster +// than ChaCha20Poly1305. +// +// When AES hardware is not available, AES-128-GCM is one or more of: much +// slower, way more complex, and less safe (because not constant time) +// than ChaCha20Poly1305. +// +// We use this list if we think both peers have AES hardware, and +// cipherSuitesPreferenceOrderNoAES otherwise. +// +// - AES-128 comes before AES-256 +// +// The only potential advantages of AES-256 are better multi-target +// margins, and hypothetical post-quantum properties. Neither apply to +// TLS, and AES-256 is slower due to its four extra rounds (which don't +// contribute to the advantages above). +// +// - ECDSA comes before RSA +// +// The relative order of ECDSA and RSA cipher suites doesn't matter, +// as they depend on the certificate. Pick one to get a stable order. +var cipherSuitesPreferenceOrder = []uint16{ + // AEADs w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // CBC w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + + // AEADs w/o ECDHE + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + + // CBC w/o ECDHE + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + + // 3DES + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var cipherSuitesPreferenceOrderNoAES = []uint16{ + // ChaCha20Poly1305 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + + // AES-GCM w/ ECDHE + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + + // The rest of cipherSuitesPreferenceOrder. + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + TLS_RSA_WITH_AES_128_CBC_SHA, + TLS_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +// disabledCipherSuites are not used unless explicitly listed in +// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder. +var disabledCipherSuites = []uint16{ + // CBC_SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, + + // RC4 + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_RSA_WITH_RC4_128_SHA, +} + +var ( + defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) + defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen] +) + +// defaultCipherSuitesTLS13 is also the preference order, since there are no +// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as +// cipherSuitesPreferenceOrder applies. +var defaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_CHACHA20_POLY1305_SHA256, +} + +var defaultCipherSuitesTLS13NoAES = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, +} + +var aesgcmCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, + // TLS 1.3 + TLS_AES_128_GCM_SHA256: true, + TLS_AES_256_GCM_SHA384: true, +} + +var nonAESGCMAEADCiphers = map[uint16]bool{ + // TLS 1.2 + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, + // TLS 1.3 + TLS_CHACHA20_POLY1305_SHA256: true, +} + +// aesgcmPreferred returns whether the first known cipher in the preference list +// is an AES-GCM cipher, implying the peer has hardware support for it. +func aesgcmPreferred(ciphers []uint16) bool { + for _, cID := range ciphers { + if c := cipherSuiteByID(cID); c != nil { + return aesgcmCiphers[cID] + } + if c := cipherSuiteTLS13ByID(cID); c != nil { + return aesgcmCiphers[cID] + } + } + return false +} + +func cipherRC4(key, iv []byte, isRead bool) any { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) any { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) any { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a SHA-1 based constant time MAC. +func macSHA1(key []byte) hash.Hash { + h := sha1.New + h = newConstantTimeHash(h) + return hmac.New(h, key) +} + +// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and +// is currently only used in disabled-by-default cipher suites. +func macSHA256(key []byte) hash.Hash { + return hmac.New(sha256.New, key) +} + +type aead interface { + cipher.AEAD + + // explicitNonceLen returns the number of bytes of explicit nonce + // included in each record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +const ( + aeadNonceLength = 12 + noncePrefixLength = 4 +) + +// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type prefixNonceAEAD struct { + // nonce contains the fixed part of the nonce in the first four bytes. + nonce [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } +func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } + +func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.nonce[4:], nonce) + return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) +} + +func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + copy(f.nonce[4:], nonce) + return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) +} + +// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} + +func aeadAESGCM(key, noncePrefix []byte) aead { + if len(noncePrefix) != noncePrefixLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + var aead cipher.AEAD + aead, err = cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &prefixNonceAEAD{aead: aead} + copy(ret.nonce[:], noncePrefix) + return ret +} + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return aeadAESGCMTLS13(key, fixedNonce) +} + +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. +func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { + h.Reset() + h.Write(seq) + h.Write(header) + h.Write(data) + res := h.Sum(out) + if extra != nil { + h.Write(extra) + } + return res +} + +func rsaKA(version uint16) keyAgreement { + return rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: false, + version: version, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + isRSA: true, + version: version, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + return cipherSuiteByID(id) + } + } + return nil +} + +func cipherSuiteByID(id uint16) *cipherSuite { + for _, cipherSuite := range cipherSuites { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { + for _, id := range have { + if id == want { + return cipherSuiteTLS13ByID(id) + } + } + return nil +} + +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { + for _, cipherSuite := range cipherSuitesTLS13 { + if cipherSuite.id == id { + return cipherSuite + } + } + return nil +} + +// A list of cipher suite IDs that are, or have been, implemented by this +// package. +// +// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + // TLS 1.0 - 1.2 cipher suites. + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 + + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator + // that the client is doing version fallback. See RFC 7507. + TLS_FALLBACK_SCSV uint16 = 0x5600 + + // Legacy names for the corresponding cipher suites with the correct _SHA256 + // suffix, retained for backward compatibility. + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/common.go b/vendor/github.com/marten-seemann/qtls-go1-19/common.go new file mode 100644 index 00000000..6670ce2f --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/common.go @@ -0,0 +1,1512 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "container/list" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "strings" + "sync" + "time" +) + +const ( + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 + + // Deprecated: SSLv3 is cryptographically broken, and is no longer + // supported by this package. See golang.org/issue/32716. + VersionSSL30 = 0x0300 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxUselessRecords = 16 // maximum number of consecutive non-advancing records +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeMessageHash uint8 = 254 // synthetic message +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +type Extension struct { + Type uint16 + Data []byte +} + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionRenegotiationInfo uint16 = 0xff01 +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +type EncryptionLevel uint8 + +const ( + EncryptionHandshake EncryptionLevel = iota + Encryption0RTT + EncryptionApplication +) + +// CurveID is a tls.CurveID +type CurveID = tls.CurveID + +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 +) + +// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. +const ( + pskModePlain uint8 = 0 + pskModeDHE uint8 = 1 +) + +// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved +// session. See RFC 8446, Section 4.2.11. +type pskIdentity struct { + label []byte + obfuscatedTicketAge uint32 +} + +// TLS Elliptic Curve Point Formats +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + certTypeRSASign = 1 + certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. +) + +// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with +// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. +const ( + signaturePKCS1v15 uint8 = iota + 225 + signatureRSAPSS + signatureECDSA + signatureEd25519 +) + +// directSigning is a standard Hash value that signals that no pre-hashing +// should be performed, and that the input should be signed directly. It is the +// hash function associated with the Ed25519 signature scheme. +var directSigning crypto.Hash = 0 + +// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that +// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ +// CertificateRequest. The two fields are merged to match with TLS 1.3. +// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. +var defaultSupportedSignatureAlgorithms = []SignatureScheme{ + PSSWithSHA256, + ECDSAWithP256AndSHA256, + Ed25519, + PSSWithSHA384, + PSSWithSHA512, + PKCS1WithSHA256, + PKCS1WithSHA384, + PKCS1WithSHA512, + ECDSAWithP384AndSHA384, + ECDSAWithP521AndSHA512, + PKCS1WithSHA1, + ECDSAWithSHA1, +} + +// helloRetryRequestRandom is set as the Random value of a ServerHello +// to signal that the message is actually a HelloRetryRequest. +var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, + 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, + 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, +} + +const ( + // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server + // random as a downgrade protection if the server would be capable of + // negotiating a higher version. See RFC 8446, Section 4.1.3. + downgradeCanaryTLS12 = "DOWNGRD\x01" + downgradeCanaryTLS11 = "DOWNGRD\x00" +) + +// testingOnlyForceDowngradeCanary is set in tests to force the server side to +// include downgrade canaries even if it's using its highers supported version. +var testingOnlyForceDowngradeCanary bool + +type ConnectionState = tls.ConnectionState + +// ConnectionState records basic TLS details about the connection. +type connectionState struct { + // Version is the TLS version used by the connection (e.g. VersionTLS12). + Version uint16 + + // HandshakeComplete is true if the handshake has concluded. + HandshakeComplete bool + + // DidResume is true if this connection was successfully resumed from a + // previous session with a session ticket or similar mechanism. + DidResume bool + + // CipherSuite is the cipher suite negotiated for the connection (e.g. + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). + CipherSuite uint16 + + // NegotiatedProtocol is the application protocol negotiated with ALPN. + NegotiatedProtocol string + + // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. + // + // Deprecated: this value is always true. + NegotiatedProtocolIsMutual bool + + // ServerName is the value of the Server Name Indication extension sent by + // the client. It's available both on the server and on the client side. + ServerName string + + // PeerCertificates are the parsed certificates sent by the peer, in the + // order in which they were sent. The first element is the leaf certificate + // that the connection is verified against. + // + // On the client side, it can't be empty. On the server side, it can be + // empty if Config.ClientAuth is not RequireAnyClientCert or + // RequireAndVerifyClientCert. + PeerCertificates []*x509.Certificate + + // VerifiedChains is a list of one or more chains where the first element is + // PeerCertificates[0] and the last element is from Config.RootCAs (on the + // client side) or Config.ClientCAs (on the server side). + // + // On the client side, it's set if Config.InsecureSkipVerify is false. On + // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven + // (and the peer provided a certificate) or RequireAndVerifyClientCert. + VerifiedChains [][]*x509.Certificate + + // SignedCertificateTimestamps is a list of SCTs provided by the peer + // through the TLS handshake for the leaf certificate, if any. + SignedCertificateTimestamps [][]byte + + // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) + // response provided by the peer for the leaf certificate, if any. + OCSPResponse []byte + + // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, + // Section 3). This value will be nil for TLS 1.3 connections and for all + // resumed connections. + // + // Deprecated: there are conditions in which this value might not be unique + // to a connection. See the Security Considerations sections of RFC 5705 and + // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. + TLSUnique []byte + + // ekm is a closure exposed via ExportKeyingMaterial. + ekm func(label string, context []byte, length int) ([]byte, error) +} + +type ConnectionStateWith0RTT struct { + ConnectionState + + Used0RTT bool // true if 0-RTT was both offered and accepted +} + +// ClientAuthType is tls.ClientAuthType +type ClientAuthType = tls.ClientAuthType + +const ( + NoClientCert = tls.NoClientCert + RequestClientCert = tls.RequestClientCert + RequireAnyClientCert = tls.RequireAnyClientCert + VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven + RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert +) + +// requiresClientCert reports whether the ClientAuthType requires a client +// certificate to be provided. +func requiresClientCert(c ClientAuthType) bool { + switch c { + case RequireAnyClientCert, RequireAndVerifyClientCert: + return true + default: + return false + } +} + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState = tls.ClientSessionState + +type clientSessionState struct { + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret + serverCertificates []*x509.Certificate // Certificate chain presented by the server + verifiedChains [][]*x509.Certificate // Certificate chains we built for verification + receivedAt time.Time // When the session ticket was received from the server + ocspResponse []byte // Stapled OCSP response presented by the server + scts [][]byte // SCTs presented by the server + + // TLS 1.3 fields. + nonce []byte // Ticket nonce sent by the server, to derive PSK + useBy time.Time // Expiration of the ticket lifetime as set by the server + ageAdd uint32 // Random obfuscation factor for sending the ticket age +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not +// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which +// are supported via this interface. +//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-17 ClientSessionCache" +type ClientSessionCache = tls.ClientSessionCache + +// SignatureScheme is a tls.SignatureScheme +type SignatureScheme = tls.SignatureScheme + +const ( + // RSASSA-PKCS1-v1_5 algorithms. + PKCS1WithSHA256 SignatureScheme = 0x0401 + PKCS1WithSHA384 SignatureScheme = 0x0501 + PKCS1WithSHA512 SignatureScheme = 0x0601 + + // RSASSA-PSS algorithms with public key OID rsaEncryption. + PSSWithSHA256 SignatureScheme = 0x0804 + PSSWithSHA384 SignatureScheme = 0x0805 + PSSWithSHA512 SignatureScheme = 0x0806 + + // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. + ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 + ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 + ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 + + // EdDSA algorithms. + Ed25519 SignatureScheme = 0x0807 + + // Legacy signature and hash algorithms for TLS 1.2. + PKCS1WithSHA1 SignatureScheme = 0x0201 + ECDSAWithSHA1 SignatureScheme = 0x0203 +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide application logic in the GetCertificate and GetConfigForClient callbacks. +type ClientHelloInfo = tls.ClientHelloInfo + +type clientHelloInfo struct { + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). + CipherSuites []uint16 + + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see RFC 4366, Section 3.1). + ServerName string + + // SupportedCurves lists the elliptic curves supported by the client. + // SupportedCurves is set only if the Supported Elliptic Curves + // Extension is being used (see RFC 4492, Section 5.1.1). + SupportedCurves []CurveID + + // SupportedPoints lists the point formats supported by the client. + // SupportedPoints is set only if the Supported Point Formats Extension + // is being used (see RFC 4492, Section 5.1.2). + SupportedPoints []uint8 + + // SignatureSchemes lists the signature and hash schemes that the client + // is willing to verify. SignatureSchemes is set only if the Signature + // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). + SignatureSchemes []SignatureScheme + + // SupportedProtos lists the application protocols supported by the client. + // SupportedProtos is set only if the Application-Layer Protocol + // Negotiation Extension is being used (see RFC 7301, Section 3.1). + // + // Servers can select a protocol by setting Config.NextProtos in a + // GetConfigForClient return value. + SupportedProtos []string + + // SupportedVersions lists the TLS versions supported by the client. + // For TLS versions less than 1.3, this is extrapolated from the max + // version advertised by the client, so values other than the greatest + // might be rejected if used. + SupportedVersions []uint16 + + // Conn is the underlying net.Conn for the connection. Do not read + // from, or write to, this connection; that will cause the TLS + // connection to fail. + Conn net.Conn + + // config is embedded by the GetCertificate or GetConfigForClient caller, + // for use with SupportsCertificate. + config *Config + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *clientHelloInfo) Context() context.Context { + return c.ctx +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo = tls.CertificateRequestInfo + +type certificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme + + // Version is the TLS version that was negotiated for this connection. + Version uint16 + + // ctx is the context of the handshake that is in progress. + ctx context.Context +} + +// Context returns the context of the handshake that is in progress. +// This context is a child of the context passed to HandshakeContext, +// if any, and is canceled when the handshake concludes. +func (c *certificateRequestInfo) Context() context.Context { + return c.ctx +} + +// RenegotiationSupport enumerates the different levels of support for TLS +// renegotiation. TLS renegotiation is the act of performing subsequent +// handshakes on a connection after the first. This significantly complicates +// the state machine and has been the source of numerous, subtle security +// issues. Initiating a renegotiation is not supported, but support for +// accepting renegotiation requests may be enabled. +// +// Even when enabled, the server may not change its identity between handshakes +// (i.e. the leaf certificate must be the same). Additionally, concurrent +// handshake and application data flow is not permitted so renegotiation can +// only be used with protocols that synchronise with the renegotiation, such as +// HTTPS. +// +// Renegotiation is not defined in TLS 1.3. +type RenegotiationSupport = tls.RenegotiationSupport + +const ( + // RenegotiateNever disables renegotiation. + RenegotiateNever = tls.RenegotiateNever + + // RenegotiateOnceAsClient allows a remote server to request + // renegotiation once per connection. + RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient + + // RenegotiateFreelyAsClient allows a remote server to repeatedly + // request renegotiation. + RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient +) + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config = tls.Config + +type config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains to present to the + // other side of the connection. The first certificate compatible with the + // peer's requirements is selected automatically. + // + // Server configurations must set one of Certificates, GetCertificate or + // GetConfigForClient. Clients doing client-authentication may set either + // Certificates or GetClientCertificate. + // + // Note: if there are multiple Certificates, and they don't have the + // optional field Leaf set, certificate selection will incur a significant + // per-handshake performance cost. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // + // Deprecated: NameToCertificate only allows associating a single + // certificate with a given name. Leave this field nil to let the library + // select the first compatible chain from Certificates. + NameToCertificate map[string]*Certificate + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // best element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // If SessionTicketKey was explicitly set on the returned Config, or if + // SetSessionTicketKeys was called on the returned Config, those keys will + // be used. Otherwise, the original Config keys will be used (and possibly + // rotated if they are automatically managed). + GetConfigForClient func(*ClientHelloInfo) (*Config, error) + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported application level protocols, in + // order of preference. If both peers support ALPN, the selected + // protocol will be one from this list, and the connection will fail + // if there is no mutually supported protocol. If NextProtos is empty + // or the peer doesn't support ALPN, the connection will succeed and + // ConnectionState.NegotiatedProtocol will be empty. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting unless it is + // an IP address. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // InsecureSkipVerify controls whether a client verifies the server's + // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls + // accepts any certificate presented by the server and any host name in that + // certificate. In this mode, TLS is susceptible to machine-in-the-middle + // attacks unless custom verification is used. This should be used only for + // testing or in combination with VerifyConnection or VerifyPeerCertificate. + InsecureSkipVerify bool + + // CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of + // the list is ignored. Note that TLS 1.3 ciphersuites are not configurable. + // + // If CipherSuites is nil, a safe default list is used. The default cipher + // suites might change over time. + CipherSuites []uint16 + + // PreferServerCipherSuites is a legacy field and has no effect. + // + // It used to control whether the server would follow the client's or the + // server's preference. Servers now select the best mutually supported + // cipher suite based on logic that takes into account inferred client + // hardware, server hardware, and security. + // + // Deprecated: PreferServerCipherSuites is ignored. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket and + // PSK (resumption) support. Note that on clients, session ticket support is + // also disabled if ClientSessionCache is nil. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session resumption. + // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled + // with random data before the first server handshake. + // + // Deprecated: if this field is left at zero, session ticket keys will be + // automatically rotated every day and dropped after seven days. For + // customizing the rotation schedule or synchronizing servers that are + // terminating connections for the same host, use SetSessionTicketKeys. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries for TLS + // session resumption. It is only used by clients. + ClientSessionCache ClientSessionCache + + // MinVersion contains the minimum TLS version that is acceptable. + // + // By default, TLS 1.2 is currently used as the minimum when acting as a + // client, and TLS 1.0 when acting as a server. TLS 1.0 is the minimum + // supported by this package, both as a client and as a server. + // + // The client-side default can temporarily be reverted to TLS 1.0 by + // including the value "x509sha1=1" in the GODEBUG environment variable. + // Note that this option will be removed in Go 1.19 (but it will still be + // possible to set this field to VersionTLS10 explicitly). + MinVersion uint16 + + // MaxVersion contains the maximum TLS version that is acceptable. + // + // By default, the maximum version supported by this package is used, + // which is currently TLS 1.3. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. The client will use the first preference as the type for + // its key share in TLS 1.3. This may change in the future. + CurvePreferences []CurveID + + // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. + // When true, the largest possible TLS record size is always used. When + // false, the size of TLS records may be adjusted in an attempt to + // improve latency. + DynamicRecordSizingDisabled bool + + // Renegotiation controls what types of renegotiation are supported. + // The default, none, is correct for the vast majority of applications. + Renegotiation RenegotiationSupport + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // mutex protects sessionTicketKeys and autoSessionTicketKeys. + mutex sync.RWMutex + // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // the keys were set with SessionTicketKey or SetSessionTicketKeys. The + // first key is used for new tickets and any subsequent keys can be used to + // decrypt old tickets. The slice contents are not protected by the mutex + // and are immutable. + sessionTicketKeys []ticketKey + // autoSessionTicketKeys is like sessionTicketKeys but is owned by the + // auto-rotation logic. See Config.ticketKeys. + autoSessionTicketKeys []ticketKey +} + +// A RecordLayer handles encrypting and decrypting of TLS messages. +type RecordLayer interface { + SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) + ReadHandshakeMessage() ([]byte, error) + WriteRecord([]byte) (int, error) + SendAlert(uint8) +} + +type ExtraConfig struct { + // GetExtensions, if not nil, is called before a message that allows + // sending of extensions is sent. + // Currently only implemented for the ClientHello message (for the client) + // and for the EncryptedExtensions message (for the server). + // Only valid for TLS 1.3. + GetExtensions func(handshakeMessageType uint8) []Extension + + // ReceivedExtensions, if not nil, is called when a message that allows the + // inclusion of extensions is received. + // It is called with an empty slice of extensions, if the message didn't + // contain any extensions. + // Currently only implemented for the ClientHello message (sent by the + // client) and for the EncryptedExtensions message (sent by the server). + // Only valid for TLS 1.3. + ReceivedExtensions func(handshakeMessageType uint8, exts []Extension) + + // AlternativeRecordLayer is used by QUIC + AlternativeRecordLayer RecordLayer + + // Enforce the selection of a supported application protocol. + // Only works for TLS 1.3. + // If enabled, client and server have to agree on an application protocol. + // Otherwise, connection establishment fails. + EnforceNextProtoSelection bool + + // If MaxEarlyData is greater than 0, the client will be allowed to send early + // data when resuming a session. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning on the client. + MaxEarlyData uint32 + + // The Accept0RTT callback is called when the client offers 0-RTT. + // The server then has to decide if it wants to accept or reject 0-RTT. + // It is only used for servers. + Accept0RTT func(appData []byte) bool + + // 0RTTRejected is called when the server rejectes 0-RTT. + // It is only used for clients. + Rejected0RTT func() + + // If set, the client will export the 0-RTT key when resuming a session that + // allows sending of early data. + // Requires the AlternativeRecordLayer to be set. + // + // It has no meaning to the server. + Enable0RTT bool + + // Is called when the client saves a session ticket to the session ticket. + // This gives the application the opportunity to save some data along with the ticket, + // which can be restored when the session ticket is used. + GetAppDataForSessionState func() []byte + + // Is called when the client uses a session ticket. + // Restores the application data that was saved earlier on GetAppDataForSessionTicket. + SetAppDataFromSessionState func([]byte) +} + +// Clone clones. +func (c *ExtraConfig) Clone() *ExtraConfig { + return &ExtraConfig{ + GetExtensions: c.GetExtensions, + ReceivedExtensions: c.ReceivedExtensions, + AlternativeRecordLayer: c.AlternativeRecordLayer, + EnforceNextProtoSelection: c.EnforceNextProtoSelection, + MaxEarlyData: c.MaxEarlyData, + Enable0RTT: c.Enable0RTT, + Accept0RTT: c.Accept0RTT, + Rejected0RTT: c.Rejected0RTT, + GetAppDataForSessionState: c.GetAppDataForSessionState, + SetAppDataFromSessionState: c.SetAppDataFromSessionState, + } +} + +func (c *ExtraConfig) usesAlternativeRecordLayer() bool { + return c != nil && c.AlternativeRecordLayer != nil +} + +const ( + // ticketKeyNameLen is the number of bytes of identifier that is prepended to + // an encrypted session ticket in order to identify the key used to encrypt it. + ticketKeyNameLen = 16 + + // ticketKeyLifetime is how long a ticket key remains valid and can be used to + // resume a client connection. + ticketKeyLifetime = 7 * 24 * time.Hour // 7 days + + // ticketKeyRotation is how often the server should rotate the session ticket key + // that is used for new tickets. + ticketKeyRotation = 24 * time.Hour +) + +// ticketKey is the internal representation of a session ticket key. +type ticketKey struct { + // keyName is an opaque byte string that serves to identify the session + // ticket key. It's exposed as plaintext in every session ticket. + keyName [ticketKeyNameLen]byte + aesKey [16]byte + hmacKey [16]byte + // created is the time at which this ticket key was created. See Config.ticketKeys. + created time.Time +} + +// ticketKeyFromBytes converts from the external representation of a session +// ticket key to a ticketKey. Externally, session ticket keys are 32 random +// bytes and this function expands that into sufficient name and key material. +func (c *config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { + hashed := sha512.Sum512(b[:]) + copy(key.keyName[:], hashed[:ticketKeyNameLen]) + copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) + copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) + key.created = c.time() + return key +} + +// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session +// ticket, and the lifetime we set for tickets we send. +const maxSessionTicketLifetime = 7 * 24 * time.Hour + +// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is +// being used concurrently by a TLS client or server. +func (c *config) Clone() *config { + if c == nil { + return nil + } + c.mutex.RLock() + defer c.mutex.RUnlock() + return &config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + GetClientCertificate: c.GetClientCertificate, + GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: c.sessionTicketKeys, + autoSessionTicketKeys: c.autoSessionTicketKeys, + } +} + +// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was +// randomized for backwards compatibility but is not in use. +var deprecatedSessionTicketKey = []byte("DEPRECATED") + +// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is +// randomized if empty, and that sessionTicketKeys is populated from it otherwise. +func (c *config) initLegacySessionTicketKeyRLocked() { + // Don't write if SessionTicketKey is already defined as our deprecated string, + // or if it is defined by the user but sessionTicketKeys is already set. + if c.SessionTicketKey != [32]byte{} && + (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { + return + } + + // We need to write some data, so get an exclusive lock and re-check any conditions. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + if c.SessionTicketKey == [32]byte{} { + if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) + } + // Write the deprecated prefix at the beginning so we know we created + // it. This key with the DEPRECATED prefix isn't used as an actual + // session ticket key, and is only randomized in case the application + // reuses it for some reason. + copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) + } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { + c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} + } + +} + +// ticketKeys returns the ticketKeys for this connection. +// If configForClient has explicitly set keys, those will +// be returned. Otherwise, the keys on c will be used and +// may be rotated if auto-managed. +// During rotation, any expired session ticket keys are deleted from +// c.sessionTicketKeys. If the session ticket key that is currently +// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) +// is not fresh, then a new session ticket key will be +// created and prepended to c.sessionTicketKeys. +func (c *config) ticketKeys(configForClient *config) []ticketKey { + // If the ConfigForClient callback returned a Config with explicitly set + // keys, use those, otherwise just use the original Config. + if configForClient != nil { + configForClient.mutex.RLock() + if configForClient.SessionTicketsDisabled { + return nil + } + configForClient.initLegacySessionTicketKeyRLocked() + if len(configForClient.sessionTicketKeys) != 0 { + ret := configForClient.sessionTicketKeys + configForClient.mutex.RUnlock() + return ret + } + configForClient.mutex.RUnlock() + } + + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.SessionTicketsDisabled { + return nil + } + c.initLegacySessionTicketKeyRLocked() + if len(c.sessionTicketKeys) != 0 { + return c.sessionTicketKeys + } + // Fast path for the common case where the key is fresh enough. + if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { + return c.autoSessionTicketKeys + } + + // autoSessionTicketKeys are managed by auto-rotation. + c.mutex.RUnlock() + defer c.mutex.RLock() + c.mutex.Lock() + defer c.mutex.Unlock() + // Re-check the condition in case it changed since obtaining the new lock. + if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { + var newKey [32]byte + if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { + panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) + } + valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) + valid = append(valid, c.ticketKeyFromBytes(newKey)) + for _, k := range c.autoSessionTicketKeys { + // While rotating the current key, also remove any expired ones. + if c.time().Sub(k.created) < ticketKeyLifetime { + valid = append(valid, k) + } + } + c.autoSessionTicketKeys = valid + } + return c.autoSessionTicketKeys +} + +// SetSessionTicketKeys updates the session ticket keys for a server. +// +// The first key will be used when creating new tickets, while all keys can be +// used for decrypting tickets. It is safe to call this function while the +// server is running in order to rotate the session ticket keys. The function +// will panic if keys is empty. +// +// Calling this function will turn off automatic session ticket key rotation. +// +// If multiple servers are terminating connections for the same host they should +// all have the same session ticket keys. If the session ticket keys leaks, +// previously recorded and future TLS connections using those keys might be +// compromised. +func (c *config) SetSessionTicketKeys(keys [][32]byte) { + if len(keys) == 0 { + panic("tls: keys must have at least one key") + } + + newKeys := make([]ticketKey, len(keys)) + for i, bytes := range keys { + newKeys[i] = c.ticketKeyFromBytes(bytes) + } + + c.mutex.Lock() + c.sessionTicketKeys = newKeys + c.mutex.Unlock() +} + +func (c *config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *config) cipherSuites() []uint16 { + if needFIPS() { + return fipsCipherSuites(c) + } + if c.CipherSuites != nil { + return c.CipherSuites + } + return defaultCipherSuites +} + +var supportedVersions = []uint16{ + VersionTLS13, + VersionTLS12, + VersionTLS11, + VersionTLS10, +} + +// roleClient and roleServer are meant to call supportedVersions and parents +// with more readability at the callsite. +const roleClient = true +const roleServer = false + +func (c *config) supportedVersions(isClient bool) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) { + continue + } + if (c == nil || c.MinVersion == 0) && + isClient && v < VersionTLS12 { + continue + } + if c != nil && c.MinVersion != 0 && v < c.MinVersion { + continue + } + if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +func (c *config) maxSupportedVersion(isClient bool) uint16 { + supportedVersions := c.supportedVersions(isClient) + if len(supportedVersions) == 0 { + return 0 + } + return supportedVersions[0] +} + +// supportedVersionsFromMax returns a list of supported versions derived from a +// legacy maximum version value. Note that only versions supported by this +// library are returned. Any newer peer will use supportedVersions anyway. +func supportedVersionsFromMax(maxVersion uint16) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if v > maxVersion { + continue + } + versions = append(versions, v) + } + return versions +} + +var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} + +func (c *config) curvePreferences() []CurveID { + if needFIPS() { + return fipsCurvePreferences(c) + } + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +func (c *config) supportsCurve(curve CurveID) bool { + for _, cc := range c.curvePreferences() { + if cc == curve { + return true + } + } + return false +} + +// mutualVersion returns the protocol version to use given the advertised +// versions of the peer. Priority is given to the peer preference order. +func (c *config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions(isClient) + for _, peerVersion := range peerVersions { + for _, v := range supportedVersions { + if v == peerVersion { + return v, true + } + } + } + return 0, false +} + +var errNoCertificates = errors.New("tls: no certificates configured") + +// getCertificate returns the best certificate for the given ClientHelloInfo, +// defaulting to the first element of c.Certificates. +func (c *config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { + cert, err := c.GetCertificate(clientHello) + if cert != nil || err != nil { + return cert, err + } + } + + if len(c.Certificates) == 0 { + return nil, errNoCertificates + } + + if len(c.Certificates) == 1 { + // There's only one choice, so no point doing any work. + return &c.Certificates[0], nil + } + + if c.NameToCertificate != nil { + name := strings.ToLower(clientHello.ServerName) + if cert, ok := c.NameToCertificate[name]; ok { + return cert, nil + } + if len(name) > 0 { + labels := strings.Split(name, ".") + labels[0] = "*" + wildcardName := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[wildcardName]; ok { + return cert, nil + } + } + } + + for _, cert := range c.Certificates { + if err := clientHello.SupportsCertificate(&cert); err == nil { + return &cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0], nil +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the client that sent the ClientHello. Otherwise, it returns an error +// describing the reason for the incompatibility. +// +// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate +// callback, this method will take into account the associated Config. Note that +// if GetConfigForClient returns a different Config, the change can't be +// accounted for by this method. +// +// This function will call x509.ParseCertificate unless c.Leaf is set, which can +// incur a significant performance cost. +func (chi *clientHelloInfo) SupportsCertificate(c *Certificate) error { + // Note we don't currently support certificate_authorities nor + // signature_algorithms_cert, and don't check the algorithms of the + // signatures on the chain (which anyway are a SHOULD, see RFC 8446, + // Section 4.4.2.2). + + config := chi.config + if config == nil { + config = &Config{} + } + conf := fromConfig(config) + vers, ok := conf.mutualVersion(roleServer, chi.SupportedVersions) + if !ok { + return errors.New("no mutually supported protocol versions") + } + + // If the client specified the name they are trying to connect to, the + // certificate needs to be valid for it. + if chi.ServerName != "" { + x509Cert, err := leafCertificate(c) + if err != nil { + return fmt.Errorf("failed to parse certificate: %w", err) + } + if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { + return fmt.Errorf("certificate is not valid for requested server name: %w", err) + } + } + + // supportsRSAFallback returns nil if the certificate and connection support + // the static RSA key exchange, and unsupported otherwise. The logic for + // supporting static RSA is completely disjoint from the logic for + // supporting signed key exchanges, so we just check it as a fallback. + supportsRSAFallback := func(unsupported error) error { + // TLS 1.3 dropped support for the static RSA key exchange. + if vers == VersionTLS13 { + return unsupported + } + // The static RSA key exchange works by decrypting a challenge with the + // RSA private key, not by signing, so check the PrivateKey implements + // crypto.Decrypter, like *rsa.PrivateKey does. + if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { + if _, ok := priv.Public().(*rsa.PublicKey); !ok { + return unsupported + } + } else { + return unsupported + } + // Finally, there needs to be a mutual cipher suite that uses the static + // RSA key exchange instead of ECDHE. + rsaCipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + return false + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if rsaCipherSuite == nil { + return unsupported + } + return nil + } + + // If the client sent the signature_algorithms extension, ensure it supports + // schemes we can use with this certificate and TLS version. + if len(chi.SignatureSchemes) > 0 { + if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { + return supportsRSAFallback(err) + } + } + + // In TLS 1.3 we are done because supported_groups is only relevant to the + // ECDHE computation, point format negotiation is removed, cipher suites are + // only relevant to the AEAD choice, and static RSA does not exist. + if vers == VersionTLS13 { + return nil + } + + // The only signed key exchange we support is ECDHE. + if !supportsECDHE(conf, chi.SupportedCurves, chi.SupportedPoints) { + return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) + } + + var ecdsaCipherSuite bool + if priv, ok := c.PrivateKey.(crypto.Signer); ok { + switch pub := priv.Public().(type) { + case *ecdsa.PublicKey: + var curve CurveID + switch pub.Curve { + case elliptic.P256(): + curve = CurveP256 + case elliptic.P384(): + curve = CurveP384 + case elliptic.P521(): + curve = CurveP521 + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + var curveOk bool + for _, c := range chi.SupportedCurves { + if c == curve && conf.supportsCurve(c) { + curveOk = true + break + } + } + if !curveOk { + return errors.New("client doesn't support certificate curve") + } + ecdsaCipherSuite = true + case ed25519.PublicKey: + if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { + return errors.New("connection doesn't support Ed25519") + } + ecdsaCipherSuite = true + case *rsa.PublicKey: + default: + return supportsRSAFallback(unsupportedCertificateError(c)) + } + } else { + return supportsRSAFallback(unsupportedCertificateError(c)) + } + + // Make sure that there is a mutually supported cipher suite that works with + // this certificate. Cipher suite selection will then apply the logic in + // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. + cipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { + if c.flags&suiteECDHE == 0 { + return false + } + if c.flags&suiteECSign != 0 { + if !ecdsaCipherSuite { + return false + } + } else { + if ecdsaCipherSuite { + return false + } + } + if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true + }) + if cipherSuite == nil { + return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) + } + + return nil +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +// +// Deprecated: NameToCertificate only allows associating a single certificate +// with a given name. Leave that field nil to let the library select the first +// compatible chain from Certificates. +func (c *config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := leafCertificate(cert) + if err != nil { + continue + } + // If SANs are *not* present, some clients will consider the certificate + // valid for the name in the Common Name. + if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +const ( + keyLogLabelTLS12 = "CLIENT_RANDOM" + keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET" + keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" + keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" + keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" +) + +func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { + if c.KeyLogWriter == nil { + return nil + } + + logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + + writerMutex.Lock() + _, err := c.KeyLogWriter.Write(logLine) + writerMutex.Unlock() + + return err +} + +// writerMutex protects all KeyLogWriters globally. It is rarely enabled, +// and is only for debugging, so a global mutex saves space. +var writerMutex sync.Mutex + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate = tls.Certificate + +// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing +// the corresponding c.Certificate[0]. +func leafCertificate(c *Certificate) (*x509.Certificate, error) { + if c.Leaf != nil { + return c.Leaf, nil + } + return x509.ParseCertificate(c.Certificate[0]) +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a ClientSessionCache implementation that uses an LRU +// caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state *ClientSessionState +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + } +} + +// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry +// corresponding to sessionKey is removed from the cache instead. +func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + if cs == nil { + c.q.Remove(elem) + delete(c.m, sessionKey) + } else { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + } + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the ClientSessionState value associated with a given key. It +// returns (nil, false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +func unexpectedMessageError(wanted, got any) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { + for _, s := range supportedSignatureAlgorithms { + if s == sigAlg { + return true + } + } + return false +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/conn.go b/vendor/github.com/marten-seemann/qtls-go1-19/conn.go new file mode 100644 index 00000000..1b275a9f --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/conn.go @@ -0,0 +1,1610 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package qtls + +import ( + "bytes" + "context" + "crypto/cipher" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + isClient bool + handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake + + // handshakeStatus is 1 if the connection is currently transferring + // application data (i.e. is not currently processing a handshake). + // handshakeStatus == 1 implies handshakeErr == nil. + // This field is only to be accessed with sync/atomic. + handshakeStatus uint32 + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *config // configuration passed to constructor + // handshakes counts the number of handshakes performed on the + // connection so far. If renegotiation is disabled then this is either + // zero or one. + extraConfig *ExtraConfig + + handshakes int + didResume bool // whether this connection was a session resumption + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + scts [][]byte // signed certificate timestamps from server + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + // secureRenegotiation is true if the server echoed the secure + // renegotiation extension. (This is meaningless as a server because + // renegotiation is not supported in that case.) + secureRenegotiation bool + // ekm is a closure for exporting keying material. + ekm func(label string, context []byte, length int) ([]byte, error) + // For the client: + // resumptionSecret is the resumption_master_secret for handling + // NewSessionTicket messages. nil if config.SessionTicketsDisabled. + // For the server: + // resumptionSecret is the resumption_master_secret for generating + // NewSessionTicket messages. Only used when the alternative record + // layer is set. nil if config.SessionTicketsDisabled. + resumptionSecret []byte + + // ticketKeys is the set of active session ticket keys for this + // connection. The first one is used to encrypt new tickets and + // all are tried to decrypt tickets. + ticketKeys []ticketKey + + // clientFinishedIsFirst is true if the client sent the first Finished + // message during the most recent handshake. This is recorded because + // the first transmitted Finished message is the tls-unique + // channel-binding value. + clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + + // clientFinished and serverFinished contain the Finished message sent + // by the client or server in the most recent handshake. This is + // retained to support the renegotiation extension and tls-unique + // channel-binding. + clientFinished [12]byte + serverFinished [12]byte + + // clientProtocol is the negotiated ALPN protocol. + clientProtocol string + + // input/output + in, out halfConn + rawInput bytes.Buffer // raw input, starting with a record header + input bytes.Reader // application data waiting to be read, from rawInput.Next + hand bytes.Buffer // handshake data waiting to be read + buffering bool // whether records are buffered in sendBuf + sendBuf []byte // a buffer of records waiting to be sent + + // bytesSent counts the bytes of application data sent. + // packetsSent counts packets. + bytesSent int64 + packetsSent int64 + + // retryCount counts the number of consecutive non-advancing records + // received by Conn.readRecord. That is, records that neither advance the + // handshake, nor deliver application data. Protected by in.Mutex. + retryCount int + + // activeCall is an atomic int32; the low bit is whether Close has + // been called. the rest of the bits are the number of goroutines + // in Conn.Write. + activeCall int32 + + used0RTT bool + + tmp [16]byte +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying connection. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// TLS session. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + cipher any // cipher algorithm + mac hash.Hash + seq [8]byte // 64-bit sequence number + + scratchBuf [13]byte // to avoid allocs; interface method args escape + + nextCipher any // next encryption state + nextMac hash.Hash // next MAC algorithm + + trafficSecret []byte // current TLS 1.3 traffic secret + + setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) +} + +type permanentError struct { + err net.Error +} + +func (e *permanentError) Error() string { return e.err.Error() } +func (e *permanentError) Unwrap() error { return e.err } +func (e *permanentError) Timeout() bool { return e.err.Timeout() } +func (e *permanentError) Temporary() bool { return false } + +func (hc *halfConn) setErrorLocked(err error) error { + if e, ok := err.(net.Error); ok { + hc.err = &permanentError{err: e} + } else { + hc.err = err + } + return hc.err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec() error { + if hc.nextCipher == nil || hc.version == VersionTLS13 { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + for i := range hc.seq { + hc.seq[i] = 0 + } + return nil +} + +func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) { + if hc.setKeyCallback != nil { + s := &CipherSuiteTLS13{ + ID: suite.id, + KeyLen: suite.keyLen, + Hash: suite.hash, + AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) }, + } + hc.setKeyCallback(encLevel, s, trafficSecret) + } +} + +func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { + hc.trafficSecret = secret + key, iv := suite.trafficKey(secret) + hc.cipher = suite.aead(key, iv) + for i := range hc.seq { + hc.seq[i] = 0 + } +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq() { + for i := 7; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + return + } + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + panic("TLS: sequence number wraparound") +} + +// explicitNonceLen returns the number of bytes of explicit nonce or IV included +// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 +// and in certain AEAD modes in TLS 1.2. +func (hc *halfConn) explicitNonceLen() int { + if hc.cipher == nil { + return 0 + } + + switch c := hc.cipher.(type) { + case cipher.Stream: + return 0 + case aead: + return c.explicitNonceLen() + case cbcMode: + // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. + if hc.version >= VersionTLS11 { + return c.BlockSize() + } + return 0 + default: + panic("unknown cipher type") + } +} + +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +func extractPadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + // Zero the padding length on error. This ensures any unchecked bytes + // are included in the MAC. Otherwise, an attacker that could + // distinguish MAC failures from padding failures could mount an attack + // similar to POODLE in SSL 3.0: given a good ciphertext that uses a + // full block's worth of padding, replace the final block with another + // block. If the MAC check passed but the padding check failed, the + // last byte of that block decrypted to the block size. + // + // See also macAndPaddingGood logic below. + paddingLen &= good + + toRemove = int(paddingLen) + 1 + return +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt authenticates and decrypts the record if protection is active at +// this stage. The returned plaintext might overlap with the input. +func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { + var plaintext []byte + typ := recordType(record[0]) + payload := record[recordHeaderLen:] + + // In TLS 1.3, change_cipher_spec messages are to be ignored without being + // decrypted. See RFC 8446, Appendix D.4. + if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { + return payload, typ, nil + } + + paddingGood := byte(255) + paddingLen := 0 + + explicitNonceLen := hc.explicitNonceLen() + + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case aead: + if len(payload) < explicitNonceLen { + return nil, 0, alertBadRecordMAC + } + nonce := payload[:explicitNonceLen] + if len(nonce) == 0 { + nonce = hc.seq[:] + } + payload = payload[explicitNonceLen:] + + var additionalData []byte + if hc.version == VersionTLS13 { + additionalData = record[:recordHeaderLen] + } else { + additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:3]...) + n := len(payload) - c.Overhead() + additionalData = append(additionalData, byte(n>>8), byte(n)) + } + + var err error + plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) + if err != nil { + return nil, 0, alertBadRecordMAC + } + case cbcMode: + blockSize := c.BlockSize() + minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) + if len(payload)%blockSize != 0 || len(payload) < minPayload { + return nil, 0, alertBadRecordMAC + } + + if explicitNonceLen > 0 { + c.SetIV(payload[:explicitNonceLen]) + payload = payload[explicitNonceLen:] + } + c.CryptBlocks(payload, payload) + + // In a limited attempt to protect against CBC padding oracles like + // Lucky13, the data past paddingLen (which is secret) is passed to + // the MAC function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC roughly constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write, modulo cache effects. + paddingLen, paddingGood = extractPadding(payload) + default: + panic("unknown cipher type") + } + + if hc.version == VersionTLS13 { + if typ != recordTypeApplicationData { + return nil, 0, alertUnexpectedMessage + } + if len(plaintext) > maxPlaintext+1 { + return nil, 0, alertRecordOverflow + } + // Remove padding and find the ContentType scanning from the end. + for i := len(plaintext) - 1; i >= 0; i-- { + if plaintext[i] != 0 { + typ = recordType(plaintext[i]) + plaintext = plaintext[:i] + break + } + if i == 0 { + return nil, 0, alertUnexpectedMessage + } + } + } + } else { + plaintext = payload + } + + if hc.mac != nil { + macSize := hc.mac.Size() + if len(payload) < macSize { + return nil, 0, alertBadRecordMAC + } + + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } + record[3] = byte(n >> 8) + record[4] = byte(n) + remoteMAC := payload[n : n+macSize] + localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) + + // This is equivalent to checking the MACs and paddingGood + // separately, but in constant-time to prevent distinguishing + // padding failures from MAC failures. Depending on what value + // of paddingLen was returned on bad padding, distinguishing + // bad MAC from bad padding can lead to an attack. + // + // See also the logic at the end of extractPadding. + macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) + if macAndPaddingGood != 1 { + return nil, 0, alertBadRecordMAC + } + + plaintext = payload[:n] + } + + hc.incSeq() + return plaintext, typ, nil +} + +func (c *Conn) setAlternativeRecordLayer() { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey + c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey + } +} + +// sliceForAppend extends the input slice by n bytes. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and +// appends it to record, which must already contain the record header. +func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { + if hc.cipher == nil { + return append(record, payload...), nil + } + + var explicitNonce []byte + if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { + record, explicitNonce = sliceForAppend(record, explicitNonceLen) + if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { + // The AES-GCM construction in TLS has an explicit nonce so that the + // nonce can be random. However, the nonce is only 8 bytes which is + // too small for a secure, random nonce. Therefore we use the + // sequence number as the nonce. The 3DES-CBC construction also has + // an 8 bytes nonce but its nonces must be unpredictable (see RFC + // 5246, Appendix F.3), forcing us to use randomness. That's not + // 3DES' biggest problem anyway because the birthday bound on block + // collision is reached first due to its similarly small block size + // (see the Sweet32 attack). + copy(explicitNonce, hc.seq[:]) + } else { + if _, err := io.ReadFull(rand, explicitNonce); err != nil { + return nil, err + } + } + } + + var dst []byte + switch c := hc.cipher.(type) { + case cipher.Stream: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + record, dst = sliceForAppend(record, len(payload)+len(mac)) + c.XORKeyStream(dst[:len(payload)], payload) + c.XORKeyStream(dst[len(payload):], mac) + case aead: + nonce := explicitNonce + if len(nonce) == 0 { + nonce = hc.seq[:] + } + + if hc.version == VersionTLS13 { + record = append(record, payload...) + + // Encrypt the actual ContentType and replace the plaintext one. + record = append(record, record[0]) + record[0] = byte(recordTypeApplicationData) + + n := len(payload) + 1 + c.Overhead() + record[3] = byte(n >> 8) + record[4] = byte(n) + + record = c.Seal(record[:recordHeaderLen], + nonce, record[recordHeaderLen:], record[:recordHeaderLen]) + } else { + additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) + additionalData = append(additionalData, record[:recordHeaderLen]...) + record = c.Seal(record, nonce, payload, additionalData) + } + case cbcMode: + mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) + blockSize := c.BlockSize() + plaintextLen := len(payload) + len(mac) + paddingLen := blockSize - plaintextLen%blockSize + record, dst = sliceForAppend(record, plaintextLen+paddingLen) + copy(dst, payload) + copy(dst[len(payload):], mac) + for i := plaintextLen; i < len(dst); i++ { + dst[i] = byte(paddingLen - 1) + } + if len(explicitNonce) > 0 { + c.SetIV(explicitNonce) + } + c.CryptBlocks(dst, dst) + default: + panic("unknown cipher type") + } + + // Update length to include nonce, MAC and any block padding needed. + n := len(record) - recordHeaderLen + record[3] = byte(n >> 8) + record[4] = byte(n) + hc.incSeq() + + return record, nil +} + +// RecordHeaderError is returned when a TLS record header is invalid. +type RecordHeaderError struct { + // Msg contains a human readable string that describes the error. + Msg string + // RecordHeader contains the five bytes of TLS record header that + // triggered the error. + RecordHeader [5]byte + // Conn provides the underlying net.Conn in the case that a client + // sent an initial handshake that didn't look like TLS. + // It is nil if there's already been a handshake or a TLS alert has + // been written to the connection. + Conn net.Conn +} + +func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } + +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { + err.Msg = msg + err.Conn = conn + copy(err.RecordHeader[:], c.rawInput.Bytes()) + return err +} + +func (c *Conn) readRecord() error { + return c.readRecordOrCCS(false) +} + +func (c *Conn) readChangeCipherSpec() error { + return c.readRecordOrCCS(true) +} + +// readRecordOrCCS reads one or more TLS records from the connection and +// updates the record layer state. Some invariants: +// - c.in must be locked +// - c.input must be empty +// +// During the handshake one and only one of the following will happen: +// - c.hand grows +// - c.in.changeCipherSpec is called +// - an error is returned +// +// After the handshake one and only one of the following will happen: +// - c.hand grows +// - c.input is set +// - an error is returned +func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { + if c.in.err != nil { + return c.in.err + } + handshakeComplete := c.handshakeComplete() + + // This function modifies c.rawInput, which owns the c.input memory. + if c.input.Len() != 0 { + return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) + } + c.input.Reset(nil) + + // Read header, payload. + if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify + // is an error, but popular web sites seem to do this, so we accept it + // if and only if at the record boundary. + if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 { + err = io.EOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + hdr := c.rawInput.Bytes()[:recordHeaderLen] + typ := recordType(hdr[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + if !handshakeComplete && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) + } + + vers := uint16(hdr[1])<<8 | uint16(hdr[2]) + n := int(hdr[3])<<8 | int(hdr[4]) + if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { + c.sendAlert(alertProtocolVersion) + msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if !c.haveVers { + // First message, be extra suspicious: this might not be a TLS + // client. Bail out before reading a full 'body', if possible. + // The current max version is 3.3 so if the version is >= 16.0, + // it's probably not real. + if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { + return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) + } + } + if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } + if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return err + } + + // Process message. + record := c.rawInput.Next(recordHeaderLen + n) + data, typ, err := c.in.decrypt(record) + if err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + if len(data) > maxPlaintext { + return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) + } + + // Application Data messages are always protected. + if c.in.cipher == nil && typ == recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + c.retryCount = 0 + } + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + switch typ { + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if alert(data[1]) == alertCloseNotify { + return c.in.setErrorLocked(io.EOF) + } + if c.vers == VersionTLS13 { + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + } + switch data[0] { + case alertLevelWarning: + // Drop the record on the floor and retry. + return c.retryReadRecord(expectChangeCipherSpec) + case alertLevelError: + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if len(data) != 1 || data[0] != 1 { + return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) + } + // Handshake messages are not allowed to fragment across the CCS. + if c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // In TLS 1.3, change_cipher_spec records are ignored until the + // Finished. See RFC 8446, Appendix D.4. Note that according to Section + // 5, a server can send a ChangeCipherSpec before its ServerHello, when + // c.vers is still unset. That's not useful though and suspicious if the + // server then selects a lower protocol version, so don't allow that. + if c.vers == VersionTLS13 { + return c.retryReadRecord(expectChangeCipherSpec) + } + if !expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + if err := c.in.changeCipherSpec(); err != nil { + return c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if !handshakeComplete || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. Ignore a limited number of empty records. + if len(data) == 0 { + return c.retryReadRecord(expectChangeCipherSpec) + } + // Note that data is owned by c.rawInput, following the Next call above, + // to avoid copying the plaintext. This is safe because c.rawInput is + // not read from or written to until c.input is drained. + c.input.Reset(data) + + case recordTypeHandshake: + if len(data) == 0 || expectChangeCipherSpec { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + c.hand.Write(data) + } + + return nil +} + +// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + } + return c.readRecordOrCCS(expectChangeCipherSpec) +} + +// atLeastReader reads from R, stopping with EOF once at least N bytes have been +// read. It is different from an io.LimitedReader in that it doesn't cut short +// the last Read call, and in that it considers an early EOF an error. +type atLeastReader struct { + R io.Reader + N int64 +} + +func (r *atLeastReader) Read(p []byte) (int, error) { + if r.N <= 0 { + return 0, io.EOF + } + n, err := r.R.Read(p) + r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 + if r.N > 0 && err == io.EOF { + return n, io.ErrUnexpectedEOF + } + if r.N <= 0 && err == nil { + return n, io.EOF + } + return n, err +} + +// readFromUntil reads from r into c.rawInput until c.rawInput contains +// at least n bytes or else returns an error. +func (c *Conn) readFromUntil(r io.Reader, n int) error { + if c.rawInput.Len() >= n { + return nil + } + needs := n - c.rawInput.Len() + // There might be extra input waiting on the wire. Make a best effort + // attempt to fetch it so that it can be used in (*Conn).Read to + // "predict" closeNotify alerts. + c.rawInput.Grow(needs + bytes.MinRead) + _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) + return err +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + + _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr + } + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) +} + +// sendAlert sends a TLS alert message. +func (c *Conn) sendAlert(err alert) error { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err)) + return &net.OpError{Op: "local error", Err: err} + } + + c.out.Lock() + defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +const ( + // tcpMSSEstimate is a conservative estimate of the TCP maximum segment + // size (MSS). A constant is used, rather than querying the kernel for + // the actual MSS, to avoid complexity. The value here is the IPv6 + // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 + // bytes) and a TCP header with timestamps (32 bytes). + tcpMSSEstimate = 1208 + + // recordSizeBoostThreshold is the number of bytes of application data + // sent after which the TLS record size will be increased to the + // maximum. + recordSizeBoostThreshold = 128 * 1024 +) + +// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the +// next application data record. There is the following trade-off: +// +// - For latency-sensitive applications, such as web browsing, each TLS +// record should fit in one TCP segment. +// - For throughput-sensitive applications, such as large file transfers, +// larger TLS records better amortize framing and encryption overheads. +// +// A simple heuristic that works well in practice is to use small records for +// the first 1MB of data, then use larger records for subsequent data, and +// reset back to smaller records after the connection becomes idle. See "High +// Performance Web Networking", Chapter 4, or: +// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ +// +// In the interests of simplicity and determinism, this code does not attempt +// to reset the record size once the connection is idle, however. +func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { + if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { + return maxPlaintext + } + + if c.bytesSent >= recordSizeBoostThreshold { + return maxPlaintext + } + + // Subtract TLS overheads to get the maximum payload size. + payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() + if c.out.cipher != nil { + switch ciph := c.out.cipher.(type) { + case cipher.Stream: + payloadBytes -= c.out.mac.Size() + case cipher.AEAD: + payloadBytes -= ciph.Overhead() + case cbcMode: + blockSize := ciph.BlockSize() + // The payload must fit in a multiple of blockSize, with + // room for at least one padding byte. + payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 + // The MAC is appended before padding so affects the + // payload size directly. + payloadBytes -= c.out.mac.Size() + default: + panic("unknown cipher type") + } + } + if c.vers == VersionTLS13 { + payloadBytes-- // encrypted ContentType + } + + // Allow packet growth in arithmetic progression up to max. + pkt := c.packetsSent + c.packetsSent++ + if pkt > 1000 { + return maxPlaintext // avoid overflow in multiply below + } + + n := payloadBytes * int(pkt+1) + if n > maxPlaintext { + n = maxPlaintext + } + return n +} + +func (c *Conn) write(data []byte) (int, error) { + if c.buffering { + c.sendBuf = append(c.sendBuf, data...) + return len(data), nil + } + + n, err := c.conn.Write(data) + c.bytesSent += int64(n) + return n, err +} + +func (c *Conn) flush() (int, error) { + if len(c.sendBuf) == 0 { + return 0, nil + } + + n, err := c.conn.Write(c.sendBuf) + c.bytesSent += int64(n) + c.sendBuf = nil + c.buffering = false + return n, err +} + +// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. +var outBufPool = sync.Pool{ + New: func() any { + return new([]byte) + }, +} + +// writeRecordLocked writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { + outBufPtr := outBufPool.Get().(*[]byte) + outBuf := *outBufPtr + defer func() { + // You might be tempted to simplify this by just passing &outBuf to Put, + // but that would make the local copy of the outBuf slice header escape + // to the heap, causing an allocation. Instead, we keep around the + // pointer to the slice header returned by Get, which is already on the + // heap, and overwrite and return that. + *outBufPtr = outBuf + outBufPool.Put(outBufPtr) + }() + + var n int + for len(data) > 0 { + m := len(data) + if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { + m = maxPayload + } + + _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) + outBuf[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } else if vers == VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = VersionTLS12 + } + outBuf[1] = byte(vers >> 8) + outBuf[2] = byte(vers) + outBuf[3] = byte(m >> 8) + outBuf[4] = byte(m) + + var err error + outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) + if err != nil { + return n, err + } + if _, err := c.write(outBuf); err != nil { + return n, err + } + n += m + data = data[m:] + } + + if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) + } + } + + return n, nil +} + +// writeRecord writes a TLS record with the given type and payload to the +// connection and updates the record layer state. +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + if typ == recordTypeChangeCipherSpec { + return len(data), nil + } + return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) + } + + c.out.Lock() + defer c.out.Unlock() + + return c.writeRecordLocked(typ, data) +} + +// readHandshake reads the next handshake message from +// the record layer. +func (c *Conn) readHandshake() (any, error) { + var data []byte + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + var err error + data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage() + if err != nil { + return nil, err + } + } else { + for c.hand.Len() < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + + data = c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + c.sendAlertLocked(alertInternalError) + return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) + } + for c.hand.Len() < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + data = c.hand.Next(4 + n) + } + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = new(clientHelloMsg) + case typeServerHello: + m = new(serverHelloMsg) + case typeNewSessionTicket: + if c.vers == VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + m = new(newSessionTicketMsg) + } + case typeCertificate: + if c.vers == VersionTLS13 { + m = new(certificateMsgTLS13) + } else { + m = new(certificateMsg) + } + case typeCertificateRequest: + if c.vers == VersionTLS13 { + m = new(certificateRequestMsgTLS13) + } else { + m = &certificateRequestMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } + case typeFinished: + m = new(finishedMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshalers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return m, nil +} + +var ( + errShutdown = errors.New("tls: protocol is shutdown") +) + +// Write writes data to the connection. +// +// As Write calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Write is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Write(b []byte) (int, error) { + // interlock with Close below + for { + x := atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return 0, net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + break + } + } + defer atomic.AddInt32(&c.activeCall, -2) + + if err := c.Handshake(); err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if err := c.out.err; err != nil { + return 0, err + } + + if !c.handshakeComplete() { + return 0, alertInternalError + } + + if c.closeNotifySent { + return 0, errShutdown + } + + // TLS 1.0 is susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // https://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // https://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && c.vers == VersionTLS10 { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecordLocked(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) +} + +// handleRenegotiation processes a HelloRequest handshake message. +func (c *Conn) handleRenegotiation() error { + if c.vers == VersionTLS13 { + return errors.New("tls: internal error: unexpected renegotiation") + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + helloReq, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(helloReq, msg) + } + + if !c.isClient { + return c.sendAlert(alertNoRenegotiation) + } + + switch c.config.Renegotiation { + case RenegotiateNever: + return c.sendAlert(alertNoRenegotiation) + case RenegotiateOnceAsClient: + if c.handshakes > 1 { + return c.sendAlert(alertNoRenegotiation) + } + case RenegotiateFreelyAsClient: + // Ok. + default: + c.sendAlert(alertInternalError) + return errors.New("tls: unknown Renegotiation value") + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + atomic.StoreUint32(&c.handshakeStatus, 0) + if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { + c.handshakes++ + } + return c.handshakeErr +} + +func (c *Conn) HandlePostHandshakeMessage() error { + return c.handlePostHandshakeMessage() +} + +// handlePostHandshakeMessage processes a handshake message arrived after the +// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. +func (c *Conn) handlePostHandshakeMessage() error { + if c.vers != VersionTLS13 { + return c.handleRenegotiation() + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) + } + + switch msg := msg.(type) { + case *newSessionTicketMsgTLS13: + return c.handleNewSessionTicket(msg) + case *keyUpdateMsg: + return c.handleKeyUpdate(msg) + default: + c.sendAlert(alertUnexpectedMessage) + return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) + } +} + +func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil { + return c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + + newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) + c.in.setTrafficSecret(cipherSuite, newSecret) + + if keyUpdate.updateRequested { + c.out.Lock() + defer c.out.Unlock() + + msg := &keyUpdateMsg{} + _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + if err != nil { + // Surface the error at the next write. + c.out.setErrorLocked(err) + return nil + } + + newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) + c.out.setTrafficSecret(cipherSuite, newSecret) + } + + return nil +} + +// Read reads data from the connection. +// +// As Read calls Handshake, in order to prevent indefinite blocking a deadline +// must be set for both Read and Write before Read is called when the handshake +// has not yet completed. See SetDeadline, SetReadDeadline, and +// SetWriteDeadline. +func (c *Conn) Read(b []byte) (int, error) { + if err := c.Handshake(); err != nil { + return 0, err + } + if len(b) == 0 { + // Put this after Handshake, in case people were calling + // Read(nil) for the side effect of the Handshake. + return 0, nil + } + + c.in.Lock() + defer c.in.Unlock() + + for c.input.Len() == 0 { + if err := c.readRecord(); err != nil { + return 0, err + } + for c.hand.Len() > 0 { + if err := c.handlePostHandshakeMessage(); err != nil { + return 0, err + } + } + } + + n, _ := c.input.Read(b) + + // If a close-notify alert is waiting, read it so that we can return (n, + // EOF) instead of (n, nil), to signal to the HTTP response reading + // goroutine that the connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would otherwise not observe + // the EOF until its next read, by which time a client goroutine might + // have already tried to reuse the HTTP connection for a new request. + // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 + if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && + recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { + if err := c.readRecord(); err != nil { + return n, err // will be io.EOF on closeNotify + } + } + + return n, nil +} + +// Close closes the connection. +func (c *Conn) Close() error { + // Interlock with Conn.Write above. + var x int32 + for { + x = atomic.LoadInt32(&c.activeCall) + if x&1 != 0 { + return net.ErrClosed + } + if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + break + } + } + if x != 0 { + // io.Writer and io.Closer should not be used concurrently. + // If Close is called while a Write is currently in-flight, + // interpret that as a sign that this Close is really just + // being used to break the Write and/or clean up resources and + // avoid sending the alertCloseNotify, which may block + // waiting on handshakeMutex or the c.out mutex. + return c.conn.Close() + } + + var alertErr error + if c.handshakeComplete() { + if err := c.closeNotify(); err != nil { + alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) + } + } + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + if !c.handshakeComplete() { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + c.out.Lock() + defer c.out.Unlock() + + if !c.closeNotifySent { + // Set a Write Deadline to prevent possibly blocking forever. + c.SetWriteDeadline(time.Now().Add(time.Second * 5)) + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + // Any subsequent writes will fail. + c.SetWriteDeadline(time.Now()) + } + return c.closeNotifyErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first Read or Write will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// HandshakeContext or the Dialer's DialContext method instead. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first Read or Write will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + // Delegate to unexported method for named return + // without confusing documented signature. + return c.handshakeContext(ctx) +} + +func (c *Conn) handshakeContext(ctx context.Context) (ret error) { + // Fast sync/atomic-based exit if there is no handshake in flight and the + // last one succeeded without an error. Avoids the expensive context setup + // and mutex for most Read and Write calls. + if c.handshakeComplete() { + return nil + } + + handshakeCtx, cancel := context.WithCancel(ctx) + // Note: defer this before starting the "interrupter" goroutine + // so that we can tell the difference between the input being canceled and + // this cancellation. In the former case, we need to close the connection. + defer cancel() + + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot). + // + // The interrupter goroutine waits for the input context to be done and + // closes the connection if this happens before the function returns. + if ctx.Done() != nil { + done := make(chan struct{}) + interruptRes := make(chan error, 1) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil { + // Return context error to user. + ret = ctxErr + } + }() + go func() { + select { + case <-handshakeCtx.Done(): + // Close the connection, discarding the error + _ = c.conn.Close() + interruptRes <- handshakeCtx.Err() + case <-done: + interruptRes <- nil + } + }() + } + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete() { + return nil + } + + c.in.Lock() + defer c.in.Unlock() + + c.handshakeErr = c.handshakeFn(handshakeCtx) + if c.handshakeErr == nil { + c.handshakes++ + } else { + // If an error occurred during the handshake try to flush the + // alert that might be left in the buffer. + c.flush() + } + + if c.handshakeErr == nil && !c.handshakeComplete() { + c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") + } + if c.handshakeErr != nil && c.handshakeComplete() { + panic("tls: internal error: handshake returned an error but is marked successful") + } + + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} + +// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection. +func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + return ConnectionStateWith0RTT{ + ConnectionState: c.connectionStateLocked(), + Used0RTT: c.used0RTT, + } +} + +func (c *Conn) connectionStateLocked() ConnectionState { + var state connectionState + state.HandshakeComplete = c.handshakeComplete() + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = true + state.ServerName = c.serverName + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] + } else { + state.TLSUnique = c.serverFinished[:] + } + } + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } + return toConnectionState(state) +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.handshakeComplete() { + return errors.New("tls: handshake has not yet been performed") + } + if len(c.verifiedChains) == 0 { + return errors.New("tls: handshake did not verify certificate chain") + } + return c.peerCertificates[0].VerifyHostname(host) +} + +func (c *Conn) handshakeComplete() bool { + return atomic.LoadUint32(&c.handshakeStatus) == 1 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/cpu.go b/vendor/github.com/marten-seemann/qtls-go1-19/cpu.go new file mode 100644 index 00000000..12194508 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/cpu.go @@ -0,0 +1,22 @@ +//go:build !js +// +build !js + +package qtls + +import ( + "runtime" + + "golang.org/x/sys/cpu" +) + +var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && + (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || + runtime.GOARCH == "arm64" && hasGCMAsmARM64 || + runtime.GOARCH == "s390x" && hasGCMAsmS390X +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go b/vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go new file mode 100644 index 00000000..33f7d219 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go @@ -0,0 +1,12 @@ +//go:build js +// +build js + +package qtls + +var ( + hasGCMAsmAMD64 = false + hasGCMAsmARM64 = false + hasGCMAsmS390X = false + + hasAESGCMHardwareSupport = false +) diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client.go b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client.go new file mode 100644 index 00000000..4407683a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client.go @@ -0,0 +1,1117 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "net" + "strings" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +const clientSessionStateVersion = 1 + +type clientHandshakeState struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *clientSessionState +} + +var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme + +func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { + config := c.config + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + nextProtosLength := 0 + for _, proto := range config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return nil, nil, errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + if nextProtosLength > 0xffff { + return nil, nil, errors.New("tls: NextProtos values too large") + } + + var supportedVersions []uint16 + var clientHelloVersion uint16 + if c.extraConfig.usesAlternativeRecordLayer() { + if config.maxSupportedVersion(roleClient) < VersionTLS13 { + return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + // Only offer TLS 1.3 when QUIC is used. + supportedVersions = []uint16{VersionTLS13} + clientHelloVersion = VersionTLS13 + } else { + supportedVersions = config.supportedVersions(roleClient) + if len(supportedVersions) == 0 { + return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") + } + clientHelloVersion = config.maxSupportedVersion(roleClient) + } + + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if clientHelloVersion > VersionTLS12 { + clientHelloVersion = VersionTLS12 + } + + hello := &clientHelloMsg{ + vers: clientHelloVersion, + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + ocspStapling: true, + scts: true, + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + secureRenegotiationSupported: true, + alpnProtocols: config.NextProtos, + supportedVersions: supportedVersions, + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + configCipherSuites := config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(configCipherSuites)) + + for _, suiteId := range preferenceOrder { + suite := mutualCipherSuite(configCipherSuites, suiteId) + if suite == nil { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + continue + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + } + + _, err := io.ReadFull(config.rand(), hello.random) + if err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + + // A random session ID is used to detect when the server accepted a ticket + // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as + // a compatibility measure (see RFC 8446, Section 4.1.2). + if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + hello.sessionId = make([]byte, 32) + if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { + return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) + } + } + + if hello.vers >= VersionTLS12 { + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + } + if testingOnlyForceClientHelloSignatureAlgorithms != nil { + hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms + } + + var params ecdheParameters + if hello.supportedVersions[0] == VersionTLS13 { + var suites []uint16 + for _, suiteID := range configCipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + suites = append(suites, suiteID) + } + } + } + if len(suites) > 0 { + hello.cipherSuites = suites + } else { + if hasAESGCMHardwareSupport { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) + } else { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) + } + } + + curveID := config.curvePreferences()[0] + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err = generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil { + hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello) + } + + return hello, params, nil +} + +func (c *Conn) clientHandshake(ctx context.Context) (err error) { + if c.config == nil { + c.config = fromConfig(defaultConfig()) + } + c.setAlternativeRecordLayer() + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, ecdheParams, err := c.makeClientHello() + if err != nil { + return err + } + c.serverName = hello.serverName + + cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + if cacheKey != "" && session != nil { + var deletedTicket bool + if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { + // don't reuse a session ticket that enabled 0-RTT + c.config.ClientSessionCache.Put(cacheKey, nil) + deletedTicket = true + + if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { + h := suite.hash.New() + h.Write(hello.marshal()) + clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) + c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + } + if !deletedTicket { + defer func() { + // If we got a handshake failure when resuming a session, throw away + // the session ticket. See RFC 5077, Section 3.2. + // + // RFC 8446 makes no mention of dropping tickets on failure, but it + // does require servers to abort on invalid binders, so we need to + // delete tickets to recover from a corrupted PSK. + if err != nil { + c.config.ClientSessionCache.Put(cacheKey, nil) + } + }() + } + } + + if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + + if err := c.pickTLSVersion(serverHello); err != nil { + return err + } + + // If we are negotiating a protocol version that's lower than what we + // support, check for the server downgrade canaries. + // See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion(roleClient) + tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 + tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 + if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || + maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") + } + + if c.vers == VersionTLS13 { + hs := &clientHandshakeStateTLS13{ + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + ecdheParams: ecdheParams, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + } + + // In TLS 1.3, session tickets are delivered after the handshake. + return hs.handshake() + } + + hs := &clientHandshakeState{ + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + session: session, + } + + if err := hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one. + if cacheKey != "" && hs.session != nil && session != hs.session { + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) + } + + return nil +} + +// extract the app data saved in the session.nonce, +// and set the session.nonce to the actual nonce value +func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) { + s := cryptobyte.String(session.nonce) + var version uint16 + if !s.ReadUint16(&version) { + return 0, nil, false + } + if version != clientSessionStateVersion { + return 0, nil, false + } + var maxEarlyData uint32 + if !s.ReadUint32(&maxEarlyData) { + return 0, nil, false + } + var appData []byte + if !readUint16LengthPrefixed(&s, &appData) { + return 0, nil, false + } + var nonce []byte + if !readUint16LengthPrefixed(&s, &nonce) { + return 0, nil, false + } + session.nonce = nonce + return maxEarlyData, appData, true +} + +func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, + session *clientSessionState, earlySecret, binderKey []byte) { + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return "", nil, nil, nil + } + + hello.ticketSupported = true + + if hello.supportedVersions[0] == VersionTLS13 { + // Require DHE on resumption as it guarantees forward secrecy against + // compromise of the session ticket key. See RFC 8446, Section 4.2.9. + hello.pskModes = []uint8{pskModeDHE} + } + + // Session resumption is not allowed if renegotiating because + // renegotiation is primarily used to allow a client to send a client + // certificate, which would be skipped if session resumption occurred. + if c.handshakes != 0 { + return "", nil, nil, nil + } + + // Try to resume a previously negotiated TLS session, if available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + sess, ok := c.config.ClientSessionCache.Get(cacheKey) + if !ok || sess == nil { + return cacheKey, nil, nil, nil + } + session = fromClientSessionState(sess) + + var appData []byte + var maxEarlyData uint32 + if session.vers == VersionTLS13 { + var ok bool + maxEarlyData, appData, ok = c.decodeSessionState(session) + if !ok { // delete it, if parsing failed + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + } + + // Check that version used for the previous session is still valid. + versOk := false + for _, v := range hello.supportedVersions { + if v == session.vers { + versOk = true + break + } + } + if !versOk { + return cacheKey, nil, nil, nil + } + + // Check that the cached server certificate is not expired, and that it's + // valid for the ServerName. This should be ensured by the cache key, but + // protect the application from a faulty ClientSessionCache implementation. + if !c.config.InsecureSkipVerify { + if len(session.verifiedChains) == 0 { + // The original connection had InsecureSkipVerify, while this doesn't. + return cacheKey, nil, nil, nil + } + serverCert := session.serverCertificates[0] + if c.config.time().After(serverCert.NotAfter) { + // Expired certificate, delete the entry. + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { + return cacheKey, nil, nil, nil + } + } + + if session.vers != VersionTLS13 { + // In TLS 1.2 the cipher suite must match the resumed session. Ensure we + // are still offering it. + if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { + return cacheKey, nil, nil, nil + } + + hello.sessionTicket = session.sessionTicket + return + } + + // Check that the session ticket is not expired. + if c.config.time().After(session.useBy) { + c.config.ClientSessionCache.Put(cacheKey, nil) + return cacheKey, nil, nil, nil + } + + // In TLS 1.3 the KDF hash must match the resumed session. Ensure we + // offer at least one cipher suite with that hash. + cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) + if cipherSuite == nil { + return cacheKey, nil, nil, nil + } + cipherSuiteOk := false + for _, offeredID := range hello.cipherSuites { + offeredSuite := cipherSuiteTLS13ByID(offeredID) + if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return cacheKey, nil, nil, nil + } + + // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. + ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) + identity := pskIdentity{ + label: session.sessionTicket, + obfuscatedTicketAge: ticketAge + session.ageAdd, + } + hello.pskIdentities = []pskIdentity{identity} + hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} + + // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. + psk := cipherSuite.expandLabel(session.masterSecret, "resumption", + session.nonce, cipherSuite.hash.Size()) + earlySecret = cipherSuite.extract(psk, nil) + binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) + if c.extraConfig != nil { + hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 + } + transcript := cipherSuite.hash.New() + transcript.Write(hello.marshalWithoutBinders()) + pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} + hello.updateBinders(pskBinders) + + if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { + c.extraConfig.SetAppDataFromSessionState(appData) + } + return +} + +func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { + peerVersion := serverHello.vers + if serverHello.supportedVersion != 0 { + peerVersion = serverHello.supportedVersion + } + + vers, ok := c.config.mutualVersion(roleClient, []uint16{peerVersion}) + if !ok { + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) + } + + c.vers = vers + c.haveVers = true + c.in.version = vers + c.out.version = vers + + return nil +} + +// Does the handshake, either a full one or resumes old session. Requires hs.c, +// hs.hello, hs.serverHello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + c := hs.c + + isResume, err := hs.processServerHello() + if err != nil { + return err + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + + // No signatures of the handshake are needed in a resumption. + // Otherwise, in a full handshake, if we don't have any certificates + // configured then we will never send a CertificateVerify message and + // thus no signatures are needed in that case either. + if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { + hs.finishedHash.discardHandshakeBuffer() + } + + hs.finishedHash.Write(hs.hello.marshal()) + hs.finishedHash.Write(hs.serverHello.marshal()) + + c.buffering = true + c.didResume = isResume + if isResume { + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(c.clientFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = true + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(c.serverFinished[:]); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + + hs.c.cipherSuite = hs.suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + msg, err = c.readHandshake() + if err != nil { + return err + } + + cs, ok := msg.(*certificateStatusMsg) + if ok { + // RFC4366 on Certificate Status Request: + // The server MAY return a "certificate_status" message. + + if !hs.serverHello.ocspStapling { + // If a server returns a "CertificateStatus" message, then the + // server MUST have included an extension of type "status_request" + // with empty "extension_data" in the extended server hello. + + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received unexpected CertificateStatus message") + } + hs.finishedHash.Write(cs.marshal()) + + c.ocspResponse = cs.response + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + + keyAgreement := hs.suite.ka(c.vers) + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.finishedHash.Write(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + hs.finishedHash.Write(certReq.marshal()) + + cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) + if chainToSend, err = c.getClientCertificate(cri); err != nil { + c.sendAlert(alertInternalError) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.finishedHash.Write(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg = new(certificateMsg) + certMsg.certificates = chainToSend.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + hs.finishedHash.Write(ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } + } + + if chainToSend != nil && len(chainToSend.Certificate) > 0 { + certVerify := &certificateVerifyMsg{} + + key, ok := chainToSend.PrivateKey.(crypto.Signer) + if !ok { + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + certVerify.hasSignatureAlgorithm = true + certVerify.signatureAlgorithm = signatureAlgorithm + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.finishedHash.Write(certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } + } + + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to write to key log: " + err.Error()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher any + var clientHash, serverHash hash.Hash + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if err := hs.pickCipherSuite(); err != nil { + return false, err + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { + c.secureRenegotiation = true + if len(hs.serverHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: initial handshake had non-empty renegotiation extension") + } + } + + if c.handshakes > 0 && c.secureRenegotiation { + var expectedSecureRenegotiation [24]byte + copy(expectedSecureRenegotiation[:], c.clientFinished[:]) + copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) + if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: incorrect renegotiation extension contents") + } + } + + if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { + c.sendAlert(alertUnsupportedExtension) + return false, err + } + c.clientProtocol = hs.serverHello.alpnProtocol + + c.scts = hs.serverHello.scts + + if !hs.serverResumedSession() { + return false, nil + } + + if hs.session.vers != c.vers { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different version") + } + + if hs.session.cipherSuite != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: server resumed a session with a different cipher suite") + } + + // Restore masterSecret, peerCerts, and ocspResponse from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + // Let the ServerHello SCTs override the session SCTs from the original + // connection, if any are provided + if len(c.scts) == 0 && len(hs.session.scts) != 0 { + c.scts = hs.session.scts + } + + return true, nil +} + +// checkALPN ensure that the server's choice of ALPN protocol is compatible with +// the protocols that we advertised in the Client Hello. +func checkALPN(clientProtos []string, serverProto string) error { + if serverProto == "" { + return nil + } + if len(clientProtos) == 0 { + return errors.New("tls: server advertised unrequested ALPN extension") + } + for _, proto := range clientProtos { + if proto == serverProto { + return nil + } + } + return errors.New("tls: server selected unadvertised ALPN protocol") +} + +func (hs *clientHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + hs.finishedHash.Write(serverFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + if !hs.serverHello.ticketSupported { + return nil + } + + c := hs.c + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + hs.finishedHash.Write(sessionTicketMsg.marshal()) + + hs.session = &clientSessionState{ + sessionTicket: sessionTicketMsg.ticket, + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + return nil +} + +func (hs *clientHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + copy(out, finished.verifyData) + return nil +} + +// verifyServerCertificate parses and verifies the provided chain, setting +// c.verifiedChains and c.peerCertificates or sending the appropriate alert. +func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + var err error + c.verifiedChains, err = certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + switch certs[0].PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) + } + + c.peerCertificates = certs + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS +// <= 1.2 CertificateRequest, making an effort to fill in missing information. +func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { + cri := &certificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + Version: vers, + ctx: ctx, + } + + var rsaAvail, ecAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecAvail = true + } + } + + if !certReq.hasSignatureAlgorithm { + // Prior to TLS 1.2, signature schemes did not exist. In this case we + // make up a list based on the acceptable certificate types, to help + // GetClientCertificate and SupportsCertificate select the right certificate. + // The hash part of the SignatureScheme is a lie here, because + // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. + switch { + case rsaAvail && ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case rsaAvail: + cri.SignatureSchemes = []SignatureScheme{ + PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, + } + case ecAvail: + cri.SignatureSchemes = []SignatureScheme{ + ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, + } + } + return toCertificateRequestInfo(cri) + } + + // Filter the signature schemes based on the certificate types. + // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). + cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) + for _, sigScheme := range certReq.supportedSignatureAlgorithms { + sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) + if err != nil { + continue + } + switch sigType { + case signatureECDSA, signatureEd25519: + if ecAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + case signatureRSAPSS, signaturePKCS1v15: + if rsaAvail { + cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) + } + } + } + + return toCertificateRequestInfo(cri) +} + +func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { + if c.config.GetClientCertificate != nil { + return c.config.GetClientCertificate(cri) + } + + for _, chain := range c.config.Certificates { + if err := cri.SupportsCertificate(&chain); err != nil { + continue + } + return &chain, nil + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + +const clientSessionCacheKeyPrefix = "qtls-" + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *config) string { + if len(config.ServerName) > 0 { + return clientSessionCacheKeyPrefix + config.ServerName + } + return clientSessionCacheKeyPrefix + serverAddr.String() +} + +// hostnameInSNI converts name into an appropriate hostname for SNI. +// Literal IP addresses and absolute FQDNs are not permitted as SNI values. +// See RFC 6066, Section 3. +func hostnameInSNI(name string) string { + host := name + if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + if i := strings.LastIndex(host, "%"); i > 0 { + host = host[:i] + } + if net.ParseIP(host) != nil { + return "" + } + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return name +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client_tls13.go new file mode 100644 index 00000000..7f05f2c6 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client_tls13.go @@ -0,0 +1,736 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/hmac" + "crypto/rsa" + "encoding/binary" + "errors" + "hash" + "sync/atomic" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +type clientHandshakeStateTLS13 struct { + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + ecdheParams ecdheParameters + + session *clientSessionState + earlySecret []byte + binderKey []byte + + certReq *certificateRequestMsgTLS13 + usingPSK bool + sentDummyCCS bool + suite *cipherSuiteTLS13 + transcript hash.Hash + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 +} + +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. +func (hs *clientHandshakeStateTLS13) handshake() error { + c := hs.c + + if needFIPS() { + return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") + } + + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, + // sections 4.1.2 and 4.1.3. + if c.handshakes > 0 { + c.sendAlert(alertProtocolVersion) + return errors.New("tls: server selected TLS 1.3 in a renegotiation") + } + + // Consistency check on the presence of a keyShare and its parameters. + if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + return c.sendAlert(alertInternalError) + } + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + hs.transcript = hs.suite.hash.New() + hs.transcript.Write(hs.hello.marshal()) + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.processHelloRetryRequest(); err != nil { + return err + } + } + + hs.transcript.Write(hs.serverHello.marshal()) + + c.buffering = true + if err := hs.processServerHello(); err != nil { + return err + } + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + if err := hs.establishHandshakeKeys(); err != nil { + return err + } + if err := hs.readServerParameters(); err != nil { + return err + } + if err := hs.readServerCertificate(); err != nil { + return err + } + if err := hs.readServerFinished(); err != nil { + return err + } + if err := hs.sendClientCertificate(); err != nil { + return err + } + if err := hs.sendClientFinished(); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// checkServerHelloOrHRR does validity checks that apply to both ServerHello and +// HelloRetryRequest messages. It sets hs.suite. +func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { + c := hs.c + + if hs.serverHello.supportedVersion == 0 { + c.sendAlert(alertMissingExtension) + return errors.New("tls: server selected TLS 1.3 using the legacy version field") + } + + if hs.serverHello.supportedVersion != VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid version after a HelloRetryRequest") + } + + if hs.serverHello.vers != VersionTLS12 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an incorrect legacy version") + } + + if hs.serverHello.ocspStapling || + hs.serverHello.ticketSupported || + hs.serverHello.secureRenegotiationSupported || + len(hs.serverHello.secureRenegotiation) != 0 || + len(hs.serverHello.alpnProtocol) != 0 || + len(hs.serverHello.scts) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") + } + + if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not echo the legacy session ID") + } + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported compression format") + } + + selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) + if hs.suite != nil && selectedSuite != hs.suite { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server changed cipher suite after a HelloRetryRequest") + } + if selectedSuite == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server chose an unconfigured cipher suite") + } + hs.suite = selectedSuite + c.cipherSuite = hs.suite.id + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and +// resends hs.hello, and reads the new ServerHello into hs.serverHello. +func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. (The idea is that the server might offload transcript + // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + hs.transcript.Write(hs.serverHello.marshal()) + + // The only HelloRetryRequest extensions we support are key_share and + // cookie, and clients must abort the handshake if the HRR would not result + // in any change in the ClientHello. + if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest message") + } + + if hs.serverHello.cookie != nil { + hs.hello.cookie = hs.serverHello.cookie + } + + if hs.serverHello.serverShare.group != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received malformed key_share extension") + } + + // If the server sent a key_share extension selecting a group, ensure it's + // a group we advertised but did not send a key share for, and send a key + // share for it this time. + if curveID := hs.serverHello.selectedGroup; curveID != 0 { + curveOK := false + for _, id := range hs.hello.supportedCurves { + if id == curveID { + curveOK = true + break + } + } + if !curveOK { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + if hs.ecdheParams.CurveID() == curveID { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.ecdheParams = params + hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + } + + hs.hello.raw = nil + if len(hs.hello.pskIdentities) > 0 { + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash == hs.suite.hash { + // Update binders and obfuscated_ticket_age. + ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) + hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd + + transcript := hs.suite.hash.New() + transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + transcript.Write(chHash) + transcript.Write(hs.serverHello.marshal()) + transcript.Write(hs.hello.marshalWithoutBinders()) + pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} + hs.hello.updateBinders(pskBinders) + } else { + // Server selected a cipher suite incompatible with the PSK. + hs.hello.pskIdentities = nil + hs.hello.pskBinders = nil + } + } + + if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + hs.hello.earlyData = false // disable 0-RTT + + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + hs.serverHello = serverHello + + if err := hs.checkServerHelloOrHRR(); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) processServerHello() error { + c := hs.c + + if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: server sent two HelloRetryRequest messages") + } + + if len(hs.serverHello.cookie) != 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: server sent a cookie in a normal ServerHello") + } + + if hs.serverHello.selectedGroup != 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: malformed key_share extension") + } + + if hs.serverHello.serverShare.group == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server did not send a key share") + } + if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected unsupported group") + } + + if !hs.serverHello.selectedIdentityPresent { + return nil + } + + if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK") + } + + if len(hs.hello.pskIdentities) != 1 || hs.session == nil { + return c.sendAlert(alertInternalError) + } + pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) + if pskSuite == nil { + return c.sendAlert(alertInternalError) + } + if pskSuite.hash != hs.suite.hash { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: server selected an invalid PSK and cipher suite pair") + } + + hs.usingPSK = true + c.didResume = true + c.peerCertificates = hs.session.serverCertificates + c.verifiedChains = hs.session.verifiedChains + c.ocspResponse = hs.session.ocspResponse + c.scts = hs.session.scts + return nil +} + +func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { + c := hs.c + + sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) + if sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid server key share") + } + + earlySecret := hs.earlySecret + if !hs.usingPSK { + earlySecret = hs.suite.extract(nil, nil) + } + handshakeSecret := hs.suite.extract(sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.out.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(handshakeSecret, "derived", nil)) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerParameters() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + // Notify the caller if 0-RTT was rejected. + if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { + c.extraConfig.Rejected0RTT() + } + c.used0RTT = encryptedExtensions.earlyData + if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { + hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) + } + hs.transcript.Write(encryptedExtensions.marshal()) + + if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { + c.sendAlert(alertUnsupportedExtension) + return err + } + c.clientProtocol = encryptedExtensions.alpnProtocol + + if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection { + if len(encryptedExtensions.alpnProtocol) == 0 { + // the server didn't select an ALPN + c.sendAlert(alertNoApplicationProtocol) + return errors.New("ALPN negotiation failed. Server didn't offer any protocols") + } + } + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerCertificate() error { + c := hs.c + + // Either a PSK or a certificate is always used, but not both. + // See RFC 8446, Section 4.1.1. + if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certReq, ok := msg.(*certificateRequestMsgTLS13) + if ok { + hs.transcript.Write(certReq.marshal()) + + hs.certReq = certReq + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + if len(certMsg.certificate.Certificate) == 0 { + c.sendAlert(alertDecodeError) + return errors.New("tls: received empty certificates message") + } + hs.transcript.Write(certMsg.marshal()) + + c.scts = certMsg.certificate.SignedCertificateTimestamps + c.ocspResponse = certMsg.certificate.OCSPStaple + + if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + + return nil +} + +func (hs *clientHandshakeStateTLS13) readServerFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + if !hmac.Equal(expectedMAC, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid server finished hash") + } + + hs.transcript.Write(finished.marshal()) + + // Derive secrets that take context through the server Finished. + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.in.setTrafficSecret(hs.suite, serverSecret) + + err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { + c := hs.c + + if hs.certReq == nil { + return nil + } + + cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{ + AcceptableCAs: hs.certReq.certificateAuthorities, + SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, + Version: c.vers, + ctx: hs.ctx, + })) + if err != nil { + return err + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *cert + certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + // If we sent an empty certificate message, skip the CertificateVerify. + if len(cert.Certificate) == 0 { + return nil + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + + certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) + if err != nil { + // getClientCertificate returned a certificate incompatible with the + // CertificateRequestInfo supported signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *clientHandshakeStateTLS13) sendClientFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.out.setTrafficSecret(hs.suite, hs.trafficSecret) + + if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + } + + return nil +} + +func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: received new session ticket from a client") + } + + if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { + return nil + } + + // See RFC 8446, Section 4.6.1. + if msg.lifetime == 0 { + return nil + } + lifetime := time.Duration(msg.lifetime) * time.Second + if lifetime > maxSessionTicketLifetime { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: received a session ticket with invalid lifetime") + } + + cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) + if cipherSuite == nil || c.resumptionSecret == nil { + return c.sendAlert(alertInternalError) + } + + // We need to save the max_early_data_size that the server sent us, in order + // to decide if we're going to try 0-RTT with this ticket. + // However, at the same time, the qtls.ClientSessionTicket needs to be equal to + // the tls.ClientSessionTicket, so we can't just add a new field to the struct. + // We therefore abuse the nonce field (which is a byte slice) + nonceWithEarlyData := make([]byte, len(msg.nonce)+4) + binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData) + copy(nonceWithEarlyData[4:], msg.nonce) + + var appData []byte + if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil { + appData = c.extraConfig.GetAppDataForSessionState() + } + var b cryptobyte.Builder + b.AddUint16(clientSessionStateVersion) // revision + b.AddUint32(msg.maxEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(appData) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(msg.nonce) + }) + + // Save the resumption_master_secret and nonce instead of deriving the PSK + // to do the least amount of work on NewSessionTicket messages before we + // know if the ticket will be used. Forward secrecy of resumed connections + // is guaranteed by the requirement for pskModeDHE. + session := &clientSessionState{ + sessionTicket: msg.label, + vers: c.vers, + cipherSuite: c.cipherSuite, + masterSecret: c.resumptionSecret, + serverCertificates: c.peerCertificates, + verifiedChains: c.verifiedChains, + receivedAt: c.config.time(), + nonce: b.BytesOrPanic(), + useBy: c.config.time().Add(lifetime), + ageAdd: msg.ageAdd, + ocspResponse: c.ocspResponse, + scts: c.scts, + } + + cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session)) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_messages.go b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_messages.go new file mode 100644 index 00000000..07193c8e --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_messages.go @@ -0,0 +1,1843 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "fmt" + "strings" + + "golang.org/x/crypto/cryptobyte" +) + +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. +func addUint64(b *cryptobyte.Builder, v uint64) { + b.AddUint32(uint32(v >> 32)) + b.AddUint32(uint32(v)) +} + +// readUint64 decodes a big-endian, 64-bit value into out and advances over it. +// It reports whether the read was successful. +func readUint64(s *cryptobyte.String, out *uint64) bool { + var hi, lo uint32 + if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { + return false + } + *out = uint64(hi)<<32 | uint64(lo) + return true +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocols []string + scts bool + supportedVersions []uint16 + cookie []byte + keyShares []keyShare + earlyData bool + pskModes []uint8 + pskIdentities []pskIdentity + pskBinders [][]byte + additionalExtensions []Extension +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeClientHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + b.AddUint16(suite) + } + }) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.compressionMethods) + }) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + b.AddUint16(extensionServerName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // name_type = host_name + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(1) // status_type = ocsp + b.AddUint16(0) // empty responder_id_list + b.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + b.AddUint16(extensionSupportedCurves) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + b.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + b.AddUint16(extensionSessionTicket) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + b.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ks := range m.keyShares { + b.AddUint16(uint16(ks.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + b.AddUint16(extensionPSKModes) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.pskModes) + }) + }) + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(psk.label) + }) + b.AddUint32(psk.obfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +// marshalWithoutBinders returns the ClientHello through the +// PreSharedKeyExtension.identities field, according to RFC 8446, Section +// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. +func (m *clientHelloMsg) marshalWithoutBinders() []byte { + bindersLen := 2 // uint16 length prefix + for _, binder := range m.pskBinders { + bindersLen += 1 // uint8 length prefix + bindersLen += len(binder) + } + + fullMessage := m.marshal() + return fullMessage[:len(fullMessage)-bindersLen] +} + +// updateBinders updates the m.pskBinders field, if necessary updating the +// cached marshaled representation. The supplied binders must have the same +// length as the current m.pskBinders. +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { + if len(pskBinders) != len(m.pskBinders) { + panic("tls: internal error: pskBinders length mismatch") + } + for i := range m.pskBinders { + if len(pskBinders[i]) != len(m.pskBinders[i]) { + panic("tls: internal error: pskBinders length mismatch") + } + } + m.pskBinders = pskBinders + if m.raw != nil { + lenWithoutBinders := len(m.marshalWithoutBinders()) + b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { + panic("tls: internal error: failed to update binders") + } + } +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + *m = clientHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) { + return false + } + + var cipherSuites cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuites) { + return false + } + m.cipherSuites = []uint16{} + m.secureRenegotiationSupported = false + for !cipherSuites.Empty() { + var suite uint16 + if !cipherSuites.ReadUint16(&suite) { + return false + } + if suite == scsvRenegotiation { + m.secureRenegotiationSupported = true + } + m.cipherSuites = append(m.cipherSuites, suite) + } + + if !readUint8LengthPrefixed(&s, &m.compressionMethods) { + return false + } + + if s.Empty() { + // ClientHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + seenExts := make(map[uint16]bool) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + if seenExts[extension] { + return false + } + seenExts[extension] = true + + switch extension { + case extensionServerName: + // RFC 6066, Section 3 + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return false + } + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return false + } + if nameType != 0 { + continue + } + if len(m.serverName) != 0 { + // Multiple names of the same name_type are prohibited. + return false + } + m.serverName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(m.serverName, ".") { + return false + } + } + case extensionStatusRequest: + // RFC 4366, Section 3.6 + var statusType uint8 + var ignored cryptobyte.String + if !extData.ReadUint8(&statusType) || + !extData.ReadUint16LengthPrefixed(&ignored) || + !extData.ReadUint16LengthPrefixed(&ignored) { + return false + } + m.ocspStapling = statusType == statusTypeOCSP + case extensionSupportedCurves: + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + var curves cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { + return false + } + for !curves.Empty() { + var curve uint16 + if !curves.ReadUint16(&curve) { + return false + } + m.supportedCurves = append(m.supportedCurves, CurveID(curve)) + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + case extensionSessionTicket: + // RFC 5077, Section 3.2 + m.ticketSupported = true + extData.ReadBytes(&m.sessionTicket, len(extData)) + case extensionSignatureAlgorithms: + // RFC 5246, Section 7.4.1.4.1 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + // RFC 8446, Section 4.2.3 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionRenegotiationInfo: + // RFC 5746, Section 3.2 + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + // RFC 7301, Section 3.1 + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(proto)) + } + case extensionSCT: + // RFC 6962, Section 3.3.1 + m.scts = true + case extensionSupportedVersions: + // RFC 8446, Section 4.2.1 + var versList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { + return false + } + for !versList.Empty() { + var vers uint16 + if !versList.ReadUint16(&vers) { + return false + } + m.supportedVersions = append(m.supportedVersions, vers) + } + case extensionCookie: + // RFC 8446, Section 4.2.2 + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // RFC 8446, Section 4.2.8 + var clientShares cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&clientShares) { + return false + } + for !clientShares.Empty() { + var ks keyShare + if !clientShares.ReadUint16((*uint16)(&ks.group)) || + !readUint16LengthPrefixed(&clientShares, &ks.data) || + len(ks.data) == 0 { + return false + } + m.keyShares = append(m.keyShares, ks) + } + case extensionEarlyData: + // RFC 8446, Section 4.2.10 + m.earlyData = true + case extensionPSKModes: + // RFC 8446, Section 4.2.9 + if !readUint8LengthPrefixed(&extData, &m.pskModes) { + return false + } + case extensionPreSharedKey: + // RFC 8446, Section 4.2.11 + if !extensions.Empty() { + return false // pre_shared_key must be the last extension + } + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return false + } + for !identities.Empty() { + var psk pskIdentity + if !readUint16LengthPrefixed(&identities, &psk.label) || + !identities.ReadUint32(&psk.obfuscatedTicketAge) || + len(psk.label) == 0 { + return false + } + m.pskIdentities = append(m.pskIdentities, psk) + } + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return false + } + for !binders.Empty() { + var binder []byte + if !readUint8LengthPrefixed(&binders, &binder) || + len(binder) == 0 { + return false + } + m.pskBinders = append(m.pskBinders, binder) + } + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: extension, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type serverHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + ocspStapling bool + ticketSupported bool + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocol string + scts [][]byte + supportedVersion uint16 + serverShare keyShare + selectedIdentityPresent bool + selectedIdentity uint16 + supportedPoints []uint8 + + // HelloRetryRequest extensions + cookie []byte + selectedGroup CurveID +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeServerHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16(m.cipherSuite) + b.AddUint8(m.compressionMethod) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + b.AddUint16(extensionSessionTicket) + b.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.scts { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.serverShare.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + *m = serverHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) || + !s.ReadUint16(&m.cipherSuite) || + !s.ReadUint8(&m.compressionMethod) { + return false + } + + if s.Empty() { + // ServerHello is optionally followed by extension data + return true + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + seenExts := make(map[uint16]bool) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + if seenExts[extension] { + return false + } + seenExts[extension] = true + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSessionTicket: + m.ticketSupported = true + case extensionRenegotiationInfo: + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { + return false + } + m.secureRenegotiationSupported = true + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + m.scts = append(m.scts, sct) + } + case extensionSupportedVersions: + if !extData.ReadUint16(&m.supportedVersion) { + return false + } + case extensionCookie: + if !readUint16LengthPrefixed(&extData, &m.cookie) || + len(m.cookie) == 0 { + return false + } + case extensionKeyShare: + // This extension has different formats in SH and HRR, accept either + // and let the handshake logic decide. See RFC 8446, Section 4.2.8. + if len(extData) == 2 { + if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { + return false + } + } else { + if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || + !readUint16LengthPrefixed(&extData, &m.serverShare.data) { + return false + } + } + case extensionPreSharedKey: + m.selectedIdentityPresent = true + if !extData.ReadUint16(&m.selectedIdentity) { + return false + } + case extensionSupportedPoints: + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string + earlyData bool + + additionalExtensions []Extension +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeEncryptedExtensions) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } + for _, ext := range m.additionalExtensions { + b.AddUint16(ext.Type) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ext.Data) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + *m = encryptedExtensionsMsg{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var ext uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&ext) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch ext { + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + case extensionEarlyData: + m.earlyData = true + default: + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type endOfEarlyDataMsg struct{} + +func (m *endOfEarlyDataMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeEndOfEarlyData + return x +} + +func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type keyUpdateMsg struct { + raw []byte + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +type newSessionTicketMsgTLS13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateRequestMsgTLS13 struct { + raw []byte + ocspStapling bool + scts bool + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateRequest) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // certificate_request_context (SHALL be zero length unless used for + // post-handshake authentication) + b.AddUint8(0) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.scts { + // RFC 8446, Section 4.4.2.1 makes no mention of + // signed_certificate_timestamp in CertificateRequest, but + // "Extensions in the Certificate message from the client MUST + // correspond to extensions in the CertificateRequest message + // from the server." and it appears in the table in Section 4.2. + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedSignatureAlgorithms) > 0 { + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.certificateAuthorities) > 0 { + b.AddUint16(extensionCertificateAuthorities) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ca := range m.certificateAuthorities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ca) + }) + } + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { + *m = certificateRequestMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSCT: + m.scts = true + case extensionSignatureAlgorithms: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + case extensionCertificateAuthorities: + var auths cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { + return false + } + for !auths.Empty() { + var ca []byte + if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { + return false + } + m.certificateAuthorities = append(m.certificateAuthorities, ca) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type certificateMsgTLS13 struct { + raw []byte + certificate Certificate + ocspStapling bool + scts bool +} + +func (m *certificateMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // certificate_request_context + + certificate := m.certificate + if !m.ocspStapling { + certificate.OCSPStaple = nil + } + if !m.scts { + certificate.SignedCertificateTimestamps = nil + } + marshalCertificate(b, certificate) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for i, cert := range certificate.Certificate { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if i > 0 { + // This library only supports OCSP and SCT for leaf certificates. + return + } + if certificate.OCSPStaple != nil { + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(certificate.OCSPStaple) + }) + }) + } + if certificate.SignedCertificateTimestamps != nil { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range certificate.SignedCertificateTimestamps { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + }) + } + }) +} + +func (m *certificateMsgTLS13) unmarshal(data []byte) bool { + *m = certificateMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !unmarshalCertificate(&s, &m.certificate) || + !s.Empty() { + return false + } + + m.scts = m.certificate.SignedCertificateTimestamps != nil + m.ocspStapling = m.certificate.OCSPStaple != nil + + return true +} + +func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + var extensions cryptobyte.String + if !readUint24LengthPrefixed(&certList, &cert) || + !certList.ReadUint16LengthPrefixed(&extensions) { + return false + } + certificate.Certificate = append(certificate.Certificate, cert) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + if len(certificate.Certificate) > 1 { + // This library only supports OCSP and SCT for leaf certificates. + continue + } + + switch extension { + case extensionStatusRequest: + var statusType uint8 + if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || + len(certificate.OCSPStaple) == 0 { + return false + } + case extensionSCT: + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + certificate.SignedCertificateTimestamps = append( + certificate.SignedCertificateTimestamps, sct) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + } + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + response []byte +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateStatus) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.response) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var statusType uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&s, &m.response) || + len(m.response) == 0 || !s.Empty() { + return false + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeFinished) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.verifyData) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + return s.Skip(1) && + readUint24LengthPrefixed(&s, &m.verifyData) && + s.Empty() +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAlgorithm indicates whether this message includes a list of + // supported signature algorithms. This change was introduced with TLS 1.2. + hasSignatureAlgorithm bool + + certificateTypes []byte + supportedSignatureAlgorithms []SignatureScheme + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 4346, Section 7.4.4. + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAlgorithm { + length += 2 + 2*len(m.supportedSignatureAlgorithms) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAlgorithm { + n := len(m.supportedSignatureAlgorithms) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAlgo := range m.supportedSignatureAlgorithms { + y[0] = uint8(sigAlgo >> 8) + y[1] = uint8(sigAlgo) + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAlgorithm { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAlgos := sigAndHashLen / 2 + m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) + for i := range m.supportedSignatureAlgorithms { + m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + + return len(data) == 0 +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAlgorithm bool // format change introduced in TLS 1.2 + signatureAlgorithm SignatureScheme + signature []byte +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateVerify) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + b.AddUint16(uint16(m.signatureAlgorithm)) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.signature) + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + if !s.Skip(4) { // message type and uint24 length field + return false + } + if m.hasSignatureAlgorithm { + if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { + return false + } + } + return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See RFC 5077, Section 3.3. + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server.go b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server.go new file mode 100644 index 00000000..e93874af --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server.go @@ -0,0 +1,905 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "hash" + "io" + "sync/atomic" + "time" +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + ctx context.Context + clientHello *clientHelloMsg + hello *serverHelloMsg + suite *cipherSuite + ecdheOk bool + ecSignOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + masterSecret []byte + cert *Certificate +} + +// serverHandshake performs a TLS handshake as a server. +func (c *Conn) serverHandshake(ctx context.Context) error { + c.setAlternativeRecordLayer() + + clientHello, err := c.readClientHello(ctx) + if err != nil { + return err + } + + if c.vers == VersionTLS13 { + hs := serverHandshakeStateTLS13{ + c: c, + ctx: ctx, + clientHello: clientHello, + } + return hs.handshake() + } else if c.extraConfig.usesAlternativeRecordLayer() { + // This should already have been caught by the check that the ClientHello doesn't + // offer any (supported) versions older than TLS 1.3. + // Check again to make sure we can't be tricked into using an older version. + c.sendAlert(alertProtocolVersion) + return errors.New("tls: negotiated TLS < 1.3 when using QUIC") + } + + hs := serverHandshakeState{ + c: c, + ctx: ctx, + clientHello: clientHello, + } + return hs.handshake() +} + +func (hs *serverHandshakeState) handshake() error { + c := hs.c + + if err := hs.processClientHello(); err != nil { + return err + } + + // For an overview of TLS handshaking, see RFC 5246, Section 7.3. + c.buffering = true + if hs.checkForResumption() { + // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true + if err := hs.doResumeHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(c.serverFinished[:]); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + c.clientFinishedIsFirst = false + if err := hs.readFinished(nil); err != nil { + return err + } + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + if err := hs.pickCipherSuite(); err != nil { + return err + } + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readFinished(c.clientFinished[:]); err != nil { + return err + } + c.clientFinishedIsFirst = true + c.buffering = true + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(nil); err != nil { + return err + } + if _, err := c.flush(); err != nil { + return err + } + } + + c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +// readClientHello reads a ClientHello message and selects the protocol version. +func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { + msg, err := c.readHandshake() + if err != nil { + return nil, err + } + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return nil, unexpectedMessageError(clientHello, msg) + } + + var configForClient *config + originalConfig := c.config + if c.config.GetConfigForClient != nil { + chi := newClientHelloInfo(ctx, c, clientHello) + if cfc, err := c.config.GetConfigForClient(chi); err != nil { + c.sendAlert(alertInternalError) + return nil, err + } else if cfc != nil { + configForClient = fromConfig(cfc) + c.config = configForClient + } + } + c.ticketKeys = originalConfig.ticketKeys(configForClient) + + clientVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + clientVersions = supportedVersionsFromMax(clientHello.vers) + } + if c.extraConfig.usesAlternativeRecordLayer() { + // In QUIC, the client MUST NOT offer any old TLS versions. + // Here, we can only check that none of the other supported versions of this library + // (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here. + for _, ver := range clientVersions { + if ver == VersionTLS13 { + continue + } + for _, v := range supportedVersions { + if ver == v { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver) + } + } + } + // Make the config we're using allows us to use TLS 1.3. + if c.config.maxSupportedVersion(roleServer) < VersionTLS13 { + c.sendAlert(alertInternalError) + return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") + } + } + c.vers, ok = c.config.mutualVersion(roleServer, clientVersions) + if !ok { + c.sendAlert(alertProtocolVersion) + return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) + } + c.haveVers = true + c.in.version = c.vers + c.out.version = c.vers + + return clientHello, nil +} + +func (hs *serverHandshakeState) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client does not support uncompressed connections") + } + + hs.hello.random = make([]byte, 32) + serverRandom := hs.hello.random + // Downgrade protection canaries. See RFC 8446, Section 4.1.3. + maxVers := c.config.maxSupportedVersion(roleServer) + if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { + if c.vers == VersionTLS12 { + copy(serverRandom[24:], downgradeCanaryTLS12) + } else { + copy(serverRandom[24:], downgradeCanaryTLS11) + } + serverRandom = serverRandom[:24] + } + _, err := io.ReadFull(c.config.rand(), serverRandom) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + c.sendAlert(alertNoApplicationProtocol) + return err + } + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + + hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + if hs.clientHello.scts { + hs.hello.scts = hs.cert.SignedCertificateTimestamps + } + + hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) + + if hs.ecdheOk { + // Although omitting the ec_point_formats extension is permitted, some + // old OpenSSL version will refuse to handshake if not present. + // + // Per RFC 4492, section 5.1.2, implementations MUST support the + // uncompressed point format. See golang.org/issue/31943. + hs.hello.supportedPoints = []uint8{pointFormatUncompressed} + } + + if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { + switch priv.Public().(type) { + case *ecdsa.PublicKey: + hs.ecSignOk = true + case ed25519.PublicKey: + hs.ecSignOk = true + case *rsa.PublicKey: + hs.rsaSignOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) + } + } + if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { + switch priv.Public().(type) { + case *rsa.PublicKey: + hs.rsaDecryptOk = true + default: + c.sendAlert(alertInternalError) + return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) + } + } + + return nil +} + +// negotiateALPN picks a shared ALPN protocol that both sides support in server +// preference order. If ALPN is not configured or the peer doesn't support it, +// it returns "" and no error. +func negotiateALPN(serverProtos, clientProtos []string) (string, error) { + if len(serverProtos) == 0 || len(clientProtos) == 0 { + return "", nil + } + var http11fallback bool + for _, s := range serverProtos { + for _, c := range clientProtos { + if s == c { + return s, nil + } + if s == "h2" && c == "http/1.1" { + http11fallback = true + } + } + } + // As a special case, let http/1.1 clients connect to h2 servers as if they + // didn't support ALPN. We used not to enforce protocol overlap, so over + // time a number of HTTP servers were configured with only "h2", but + // expected to accept connections from "http/1.1" clients. See Issue 46310. + if http11fallback { + return "", nil + } + return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) +} + +// supportsECDHE returns whether ECDHE key exchanges can be used with this +// pre-TLS 1.3 client. +func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool { + supportsCurve := false + for _, curve := range supportedCurves { + if c.supportsCurve(curve) { + supportsCurve = true + break + } + } + + supportsPointFormat := false + for _, pointFormat := range supportedPoints { + if pointFormat == pointFormatUncompressed { + supportsPointFormat = true + break + } + } + + return supportsCurve && supportsPointFormat +} + +func (hs *serverHandshakeState) pickCipherSuite() error { + c := hs.c + + preferenceOrder := cipherSuitesPreferenceOrder + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceOrder = cipherSuitesPreferenceOrderNoAES + } + + configCipherSuites := c.config.cipherSuites() + preferenceList := make([]uint16, 0, len(configCipherSuites)) + for _, suiteID := range preferenceOrder { + for _, id := range configCipherSuites { + if id == suiteID { + preferenceList = append(preferenceList, id) + break + } + } + } + + hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk) + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // The client is doing a fallback connection. See RFC 7507. + if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + return nil +} + +func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { + if c.flags&suiteECDHE != 0 { + if !hs.ecdheOk { + return false + } + if c.flags&suiteECSign != 0 { + if !hs.ecSignOk { + return false + } + } else if !hs.rsaSignOk { + return false + } + } else if !hs.rsaDecryptOk { + return false + } + if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { + return false + } + return true +} + +// checkForResumption reports whether we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if c.config.SessionTicketsDisabled { + return false + } + + plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) + if plaintext == nil { + return false + } + hs.sessionState = &sessionState{usedOldKey: usedOldKey} + ok := hs.sessionState.unmarshal(plaintext) + if !ok { + return false + } + + createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + return false + } + + // Never resume a session for a different TLS version. + if c.vers != hs.sessionState.vers { + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return false + } + + // Check that we also support the ciphersuite from the session. + hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, + c.config.cipherSuites(), hs.cipherSuiteOk) + if hs.suite == nil { + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + return false + } + + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + hs.hello.cipherSuite = hs.suite.id + c.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = hs.sessionState.usedOldKey + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := c.processCertsFromClient(Certificate{ + Certificate: hs.sessionState.certificates, + }); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + + return nil +} + +func (hs *serverHandshakeState) doFullHandshake() error { + c := hs.c + + if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled + hs.hello.cipherSuite = hs.suite.id + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + if c.config.ClientAuth == NoClientCert { + // No need to keep a full record of the handshake if client + // certificates won't be used. + hs.finishedHash.discardHandshakeBuffer() + } + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.response = hs.cert.OCSPStaple + hs.finishedHash.Write(certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } + } + + keyAgreement := hs.suite.ka(c.vers) + skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil { + hs.finishedHash.Write(skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } + } + + var certReq *certificateRequestMsg + if c.config.ClientAuth >= RequestClientCert { + // Request a client certificate + certReq = new(certificateRequestMsg) + certReq.certificateTypes = []byte{ + byte(certTypeRSASign), + byte(certTypeECDSASign), + } + if c.vers >= VersionTLS12 { + certReq.hasSignatureAlgorithm = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + hs.finishedHash.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + helloDone := new(serverHelloDoneMsg) + hs.finishedHash.Write(helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } + + if _, err := c.flush(); err != nil { + return err + } + + var pub crypto.PublicKey // public key for client auth, if any + + msg, err := c.readHandshake() + if err != nil { + return err + } + + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + if c.config.ClientAuth >= RequestClientCert { + certMsg, ok := msg.(*certificateMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.finishedHash.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(Certificate{ + Certificate: certMsg.certificates, + }); err != nil { + return err + } + if len(certMsg.certificates) != 0 { + pub = c.peerCertificates[0].PublicKey + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.finishedHash.Write(ckx.marshal()) + + preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) + if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + var sigType uint8 + var sigHash crypto.Hash + if c.vers >= VersionTLS12 { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err + } + } + + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.finishedHash.Write(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher any + var clientHash, serverHash hash.Hash + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + + return nil +} + +func (hs *serverHandshakeState) readFinished(out []byte) error { + c := hs.c + + if err := c.readChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") + } + + hs.finishedHash.Write(clientFinished.marshal()) + copy(out, verify) + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + // ticketSupported is set in a resumption handshake if the + // ticket from the client was encrypted with an old session + // ticket key and thus a refreshed ticket should be sent. + if !hs.hello.ticketSupported { + return nil + } + + c := hs.c + m := new(newSessionTicketMsg) + + createdAt := uint64(c.config.time().Unix()) + if hs.sessionState != nil { + // If this is re-wrapping an old key, then keep + // the original time it was created. + createdAt = hs.sessionState.createdAt + } + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + createdAt: createdAt, + masterSecret: hs.masterSecret, + certificates: certsFromClient, + } + var err error + m.ticket, err = c.encryptTicket(state.marshal()) + if err != nil { + return err + } + + hs.finishedHash.Write(m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeState) sendFinished(out []byte) error { + c := hs.c + + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + hs.finishedHash.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + copy(out, finished.verifyData) + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (c *Conn) processCertsFromClient(certificate Certificate) error { + certificates := certificate.Certificate + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to verify client certificate: " + err.Error()) + } + + c.verifiedChains = chains + } + + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + return nil +} + +func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { + supportedVersions := clientHello.supportedVersions + if len(clientHello.supportedVersions) == 0 { + supportedVersions = supportedVersionsFromMax(clientHello.vers) + } + + return toClientHelloInfo(&clientHelloInfo{ + CipherSuites: clientHello.cipherSuites, + ServerName: clientHello.serverName, + SupportedCurves: clientHello.supportedCurves, + SupportedPoints: clientHello.supportedPoints, + SignatureSchemes: clientHello.supportedSignatureAlgorithms, + SupportedProtos: clientHello.alpnProtocols, + SupportedVersions: supportedVersions, + Conn: c.conn, + config: toConfig(c.config), + ctx: ctx, + }) +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server_tls13.go new file mode 100644 index 00000000..e3db8063 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server_tls13.go @@ -0,0 +1,900 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "context" + "crypto" + "crypto/hmac" + "crypto/rsa" + "errors" + "hash" + "io" + "sync/atomic" + "time" +) + +// maxClientPSKIdentities is the number of client PSK identities the server will +// attempt to validate. It will ignore the rest not to let cheap ClientHello +// messages cause too much work in session ticket decryption attempts. +const maxClientPSKIdentities = 5 + +type serverHandshakeStateTLS13 struct { + c *Conn + ctx context.Context + clientHello *clientHelloMsg + hello *serverHelloMsg + alpnNegotiationErr error + encryptedExtensions *encryptedExtensionsMsg + sentDummyCCS bool + usingPSK bool + suite *cipherSuiteTLS13 + cert *Certificate + sigAlg SignatureScheme + earlySecret []byte + sharedKey []byte + handshakeSecret []byte + masterSecret []byte + trafficSecret []byte // client_application_traffic_secret_0 + transcript hash.Hash + clientFinished []byte +} + +func (hs *serverHandshakeStateTLS13) handshake() error { + c := hs.c + + if needFIPS() { + return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") + } + + // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. + if err := hs.processClientHello(); err != nil { + return err + } + if err := hs.checkForResumption(); err != nil { + return err + } + if err := hs.pickCertificate(); err != nil { + return err + } + c.buffering = true + if err := hs.sendServerParameters(); err != nil { + return err + } + if err := hs.sendServerCertificate(); err != nil { + return err + } + if err := hs.sendServerFinished(); err != nil { + return err + } + // Note that at this point we could start sending application data without + // waiting for the client's second flight, but the application might not + // expect the lack of replay protection of the ClientHello parameters. + if _, err := c.flush(); err != nil { + return err + } + if err := hs.readClientCertificate(); err != nil { + return err + } + if err := hs.readClientFinished(); err != nil { + return err + } + + atomic.StoreUint32(&c.handshakeStatus, 1) + + return nil +} + +func (hs *serverHandshakeStateTLS13) processClientHello() error { + c := hs.c + + hs.hello = new(serverHelloMsg) + hs.encryptedExtensions = new(encryptedExtensionsMsg) + + // TLS 1.3 froze the ServerHello.legacy_version field, and uses + // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. + hs.hello.vers = VersionTLS12 + hs.hello.supportedVersion = c.vers + + if len(hs.clientHello.supportedVersions) == 0 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") + } + + // Abort if the client is doing a fallback and landing lower than what we + // support. See RFC 7507, which however does not specify the interaction + // with supported_versions. The only difference is that with + // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] + // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, + // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to + // TLS 1.2, because a TLS 1.3 server would abort here. The situation before + // supported_versions was not better because there was just no way to do a + // TLS 1.4 handshake without risking the server selecting TLS 1.3. + for _, id := range hs.clientHello.cipherSuites { + if id == TLS_FALLBACK_SCSV { + // Use c.vers instead of max(supported_versions) because an attacker + // could defeat this by adding an arbitrary high version otherwise. + if c.vers < c.config.maxSupportedVersion(roleServer) { + c.sendAlert(alertInappropriateFallback) + return errors.New("tls: client using inappropriate protocol fallback") + } + break + } + } + + if len(hs.clientHello.compressionMethods) != 1 || + hs.clientHello.compressionMethods[0] != compressionNone { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: TLS 1.3 client supports illegal compression methods") + } + + hs.hello.random = make([]byte, 32) + if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { + c.sendAlert(alertInternalError) + return err + } + + if len(hs.clientHello.secureRenegotiation) != 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: initial handshake had non-empty renegotiation extension") + } + + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.compressionMethod = compressionNone + + if hs.suite == nil { + var preferenceList []uint16 + for _, suiteID := range c.config.CipherSuites { + for _, suite := range cipherSuitesTLS13 { + if suite.id == suiteID { + preferenceList = append(preferenceList, suiteID) + break + } + } + } + if len(preferenceList) == 0 { + preferenceList = defaultCipherSuitesTLS13 + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = defaultCipherSuitesTLS13NoAES + } + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) + if hs.suite != nil { + break + } + } + } + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no cipher suite supported by both client and server") + } + c.cipherSuite = hs.suite.id + hs.hello.cipherSuite = hs.suite.id + hs.transcript = hs.suite.hash.New() + + // Pick the ECDHE group in server preference order, but give priority to + // groups with a key share, to avoid a HelloRetryRequest round-trip. + var selectedGroup CurveID + var clientKeyShare *keyShare +GroupSelection: + for _, preferredGroup := range c.config.curvePreferences() { + for _, ks := range hs.clientHello.keyShares { + if ks.group == preferredGroup { + selectedGroup = ks.group + clientKeyShare = &ks + break GroupSelection + } + } + if selectedGroup != 0 { + continue + } + for _, group := range hs.clientHello.supportedCurves { + if group == preferredGroup { + selectedGroup = group + break + } + } + } + if selectedGroup == 0 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: no ECDHE curve supported by both client and server") + } + if clientKeyShare == nil { + if err := hs.doHelloRetryRequest(selectedGroup); err != nil { + return err + } + clientKeyShare = &hs.clientHello.keyShares[0] + } + + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} + hs.sharedKey = params.SharedKey(clientKeyShare.data) + if hs.sharedKey == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid client key share") + } + + c.serverName = hs.clientHello.serverName + + if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil { + c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions) + } + + selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) + if err != nil { + hs.alpnNegotiationErr = err + } + hs.encryptedExtensions.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + + return nil +} + +func (hs *serverHandshakeStateTLS13) checkForResumption() error { + c := hs.c + + if c.config.SessionTicketsDisabled { + return nil + } + + modeOK := false + for _, mode := range hs.clientHello.pskModes { + if mode == pskModeDHE { + modeOK = true + break + } + } + if !modeOK { + return nil + } + + if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: invalid or missing PSK binders") + } + if len(hs.clientHello.pskIdentities) == 0 { + return nil + } + + for i, identity := range hs.clientHello.pskIdentities { + if i >= maxClientPSKIdentities { + break + } + + plaintext, _ := c.decryptTicket(identity.label) + if plaintext == nil { + continue + } + sessionState := new(sessionStateTLS13) + if ok := sessionState.unmarshal(plaintext); !ok { + continue + } + + if hs.clientHello.earlyData { + if sessionState.maxEarlyData == 0 { + c.sendAlert(alertUnsupportedExtension) + return errors.New("tls: client sent unexpected early data") + } + + if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol && + c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 && + c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) { + hs.encryptedExtensions.earlyData = true + c.used0RTT = true + } + } + + createdAt := time.Unix(int64(sessionState.createdAt), 0) + if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { + continue + } + + // We don't check the obfuscated ticket age because it's affected by + // clock skew and it's only a freshness signal useful for shrinking the + // window for replay attacks, which don't affect us as we don't do 0-RTT. + + pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) + if pskSuite == nil || pskSuite.hash != hs.suite.hash { + continue + } + + // PSK connections don't re-establish client certificates, but carry + // them over in the session ticket. Ensure the presence of client certs + // in the ticket is consistent with the configured requirements. + sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 + needClientCerts := requiresClientCert(c.config.ClientAuth) + if needClientCerts && !sessionHasClientCerts { + continue + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + continue + } + + psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", + nil, hs.suite.hash.Size()) + hs.earlySecret = hs.suite.extract(psk, nil) + binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) + // Clone the transcript in case a HelloRetryRequest was recorded. + transcript := cloneHash(hs.transcript, hs.suite.hash) + if transcript == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: failed to clone hash") + } + transcript.Write(hs.clientHello.marshalWithoutBinders()) + pskBinder := hs.suite.finishedHash(binderKey, transcript) + if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid PSK binder") + } + + c.didResume = true + if err := c.processCertsFromClient(sessionState.certificate); err != nil { + return err + } + + h := cloneHash(hs.transcript, hs.suite.hash) + h.Write(hs.clientHello.marshal()) + if hs.encryptedExtensions.earlyData { + clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) + c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) + if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil { + c.sendAlert(alertInternalError) + return err + } + } + + hs.hello.selectedIdentityPresent = true + hs.hello.selectedIdentity = uint16(i) + hs.usingPSK = true + return nil + } + + return nil +} + +// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// interfaces implemented by standard library hashes to clone the state of in +// to a new instance of h. It returns nil if the operation fails. +func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { + // Recreate the interface to avoid importing encoding. + type binaryMarshaler interface { + MarshalBinary() (data []byte, err error) + UnmarshalBinary(data []byte) error + } + marshaler, ok := in.(binaryMarshaler) + if !ok { + return nil + } + state, err := marshaler.MarshalBinary() + if err != nil { + return nil + } + out := h.New() + unmarshaler, ok := out.(binaryMarshaler) + if !ok { + return nil + } + if err := unmarshaler.UnmarshalBinary(state); err != nil { + return nil + } + return out +} + +func (hs *serverHandshakeStateTLS13) pickCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. + if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { + return c.sendAlert(alertMissingExtension) + } + + certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) + if err != nil { + if err == errNoCertificates { + c.sendAlert(alertUnrecognizedName) + } else { + c.sendAlert(alertInternalError) + } + return err + } + hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) + if err != nil { + // getCertificate returned a certificate that is unsupported or + // incompatible with the client's signature algorithms. + c.sendAlert(alertHandshakeFailure) + return err + } + hs.cert = certificate + + return nil +} + +// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility +// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. +func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { + if hs.sentDummyCCS { + return nil + } + hs.sentDummyCCS = true + + _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + return err +} + +func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { + c := hs.c + + // The first ClientHello gets double-hashed into the transcript upon a + // HelloRetryRequest. See RFC 8446, Section 4.4.1. + hs.transcript.Write(hs.clientHello.marshal()) + chHash := hs.transcript.Sum(nil) + hs.transcript.Reset() + hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) + hs.transcript.Write(chHash) + + helloRetryRequest := &serverHelloMsg{ + vers: hs.hello.vers, + random: helloRetryRequestRandom, + sessionId: hs.hello.sessionId, + cipherSuite: hs.hello.cipherSuite, + compressionMethod: hs.hello.compressionMethod, + supportedVersion: hs.hello.supportedVersion, + selectedGroup: selectedGroup, + } + + hs.transcript.Write(helloRetryRequest.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + clientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientHello, msg) + } + + if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client sent invalid key share in second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client indicated early data in second ClientHello") + } + + if illegalClientHelloChange(clientHello, hs.clientHello) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client illegally modified second ClientHello") + } + + if clientHello.earlyData { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client offered 0-RTT data in second ClientHello") + } + + hs.clientHello = clientHello + return nil +} + +// illegalClientHelloChange reports whether the two ClientHello messages are +// different, with the exception of the changes allowed before and after a +// HelloRetryRequest. See RFC 8446, Section 4.1.2. +func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { + if len(ch.supportedVersions) != len(ch1.supportedVersions) || + len(ch.cipherSuites) != len(ch1.cipherSuites) || + len(ch.supportedCurves) != len(ch1.supportedCurves) || + len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || + len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || + len(ch.alpnProtocols) != len(ch1.alpnProtocols) { + return true + } + for i := range ch.supportedVersions { + if ch.supportedVersions[i] != ch1.supportedVersions[i] { + return true + } + } + for i := range ch.cipherSuites { + if ch.cipherSuites[i] != ch1.cipherSuites[i] { + return true + } + } + for i := range ch.supportedCurves { + if ch.supportedCurves[i] != ch1.supportedCurves[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithms { + if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { + return true + } + } + for i := range ch.supportedSignatureAlgorithmsCert { + if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { + return true + } + } + for i := range ch.alpnProtocols { + if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { + return true + } + } + return ch.vers != ch1.vers || + !bytes.Equal(ch.random, ch1.random) || + !bytes.Equal(ch.sessionId, ch1.sessionId) || + !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || + ch.serverName != ch1.serverName || + ch.ocspStapling != ch1.ocspStapling || + !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || + ch.ticketSupported != ch1.ticketSupported || + !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || + ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || + !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || + ch.scts != ch1.scts || + !bytes.Equal(ch.cookie, ch1.cookie) || + !bytes.Equal(ch.pskModes, ch1.pskModes) +} + +func (hs *serverHandshakeStateTLS13) sendServerParameters() error { + c := hs.c + + hs.transcript.Write(hs.clientHello.marshal()) + hs.transcript.Write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } + + if err := hs.sendDummyChangeCipherSpec(); err != nil { + return err + } + + earlySecret := hs.earlySecret + if earlySecret == nil { + earlySecret = hs.suite.extract(nil, nil) + } + hs.handshakeSecret = hs.suite.extract(hs.sharedKey, + hs.suite.deriveSecret(earlySecret, "derived", nil)) + + clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, + clientHandshakeTrafficLabel, hs.transcript) + c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret) + c.in.setTrafficSecret(hs.suite, clientSecret) + serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, + serverHandshakeTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + if hs.alpnNegotiationErr != nil { + c.sendAlert(alertNoApplicationProtocol) + return hs.alpnNegotiationErr + } + if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil { + hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) + } + + hs.transcript.Write(hs.encryptedExtensions.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) requestClientCert() bool { + return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK +} + +func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { + c := hs.c + + // Only one of PSK and certificates are used at a time. + if hs.usingPSK { + return nil + } + + if hs.requestClientCert() { + // Request a client certificate + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() + } + + hs.transcript.Write(certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } + } + + certMsg := new(certificateMsgTLS13) + + certMsg.certificate = *hs.cert + certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 + certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 + + hs.transcript.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + certVerifyMsg := new(certificateVerifyMsg) + certVerifyMsg.hasSignatureAlgorithm = true + certVerifyMsg.signatureAlgorithm = hs.sigAlg + + sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) + if err != nil { + return c.sendAlert(alertInternalError) + } + + signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + if err != nil { + public := hs.cert.PrivateKey.(crypto.Signer).Public() + if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && + rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS + c.sendAlert(alertHandshakeFailure) + } else { + c.sendAlert(alertInternalError) + } + return errors.New("tls: failed to sign handshake: " + err.Error()) + } + certVerifyMsg.signature = sig + + hs.transcript.Write(certVerifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) sendServerFinished() error { + c := hs.c + + finished := &finishedMsg{ + verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), + } + + hs.transcript.Write(finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } + + // Derive secrets that take context through the server Finished. + + hs.masterSecret = hs.suite.extract(nil, + hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) + + hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, + clientApplicationTrafficLabel, hs.transcript) + serverSecret := hs.suite.deriveSecret(hs.masterSecret, + serverApplicationTrafficLabel, hs.transcript) + c.out.exportKey(EncryptionApplication, hs.suite, serverSecret) + c.out.setTrafficSecret(hs.suite, serverSecret) + + err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) + + // If we did not request client certificates, at this point we can + // precompute the client finished and roll the transcript forward to send + // session tickets in our first flight. + if !hs.requestClientCert() { + if err := hs.sendSessionTickets(); err != nil { + return err + } + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { + if hs.c.config.SessionTicketsDisabled { + return false + } + + // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. + for _, pskMode := range hs.clientHello.pskModes { + if pskMode == pskModeDHE { + return true + } + } + return false +} + +func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { + c := hs.c + + hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) + finishedMsg := &finishedMsg{ + verifyData: hs.clientFinished, + } + hs.transcript.Write(finishedMsg.marshal()) + + if !hs.shouldSendSessionTickets() { + return nil + } + + c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, + resumptionLabel, hs.transcript) + + // Don't send session tickets when the alternative record layer is set. + // Instead, save the resumption secret on the Conn. + // Session tickets can then be generated by calling Conn.GetSessionTicket(). + if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil { + return nil + } + + m, err := hs.c.getSessionTicketMsg(nil) + if err != nil { + return err + } + + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientCertificate() error { + c := hs.c + + if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil + } + + // If we requested a client certificate, then the client must send a + // certificate message. If it's empty, no CertificateVerify is sent. + + msg, err := c.readHandshake() + if err != nil { + return err + } + + certMsg, ok := msg.(*certificateMsgTLS13) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.transcript.Write(certMsg.marshal()) + + if err := c.processCertsFromClient(certMsg.certificate); err != nil { + return err + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if len(certMsg.certificate.Certificate) != 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // See RFC 8446, Section 4.4.3. + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) + if err != nil { + return c.sendAlert(alertInternalError) + } + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: client certificate used with invalid signature algorithm") + } + signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, + sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } + + hs.transcript.Write(certVerify.marshal()) + } + + // If we waited until the client certificates to send session tickets, we + // are ready to do it now. + if err := hs.sendSessionTickets(); err != nil { + return err + } + + return nil +} + +func (hs *serverHandshakeStateTLS13) readClientFinished() error { + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return err + } + + finished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(finished, msg) + } + + if !hmac.Equal(hs.clientFinished, finished.verifyData) { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid client finished hash") + } + + c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) + c.in.setTrafficSecret(hs.suite, hs.trafficSecret) + + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/key_agreement.go b/vendor/github.com/marten-seemann/qtls-go1-19/key_agreement.go new file mode 100644 index 00000000..453a8dcf --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/key_agreement.go @@ -0,0 +1,357 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "errors" + "fmt" + "io" +) + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct{} + +func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext := ckx.ciphertext[2:] + + priv, ok := cert.PrivateKey.(crypto.Decrypter) + if !ok { + return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") + } + // Perform constant time RSA PKCS #1 v1.5 decryption + preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) + if err != nil { + return nil, err + } + // We don't check the version number in the premaster secret. For one, + // by checking it, we would leak information about the validity of the + // encrypted pre-master secret. Secondly, it provides only a small + // benefit against a downgrade attack and some implementations send the + // wrong version anyway. See the discussion at the end of section + // 7.4.7.1 of RFC 4346. + return preMasterSecret, nil +} + +func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + preMasterSecret[0] = byte(clientHello.vers >> 8) + preMasterSecret[1] = byte(clientHello.vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") + } + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// using the given hash function (for >= TLS 1.2) or using a default based on +// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't +// do pre-hashing, it returns the concatenation of the slices. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { + if sigType == signatureEd25519 { + var signed []byte + for _, slice := range slices { + signed = append(signed, slice...) + } + return signed + } + if version >= VersionTLS12 { + h := hashFunc.New() + for _, slice := range slices { + h.Write(slice) + } + digest := h.Sum(nil) + return digest + } + if sigType == signatureECDSA { + return sha1Hash(slices) + } + return md5SHA1Hash(slices) +} + +// ecdheKeyAgreement implements a TLS key agreement where the server +// generates an ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// be ECDSA, Ed25519 or RSA. +type ecdheKeyAgreement struct { + version uint16 + isRSA bool + params ecdheParameters + + // ckx and preMasterSecret are generated in processServerKeyExchange + // and returned in generateClientKeyExchange. + ckx *clientKeyExchangeMsg + preMasterSecret []byte +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + var curveID CurveID + for _, c := range clientHello.supportedCurves { + if config.supportsCurve(c) { + curveID = c + break + } + } + + if curveID == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, err + } + ka.params = params + + // See RFC 4492, Section 5.4. + ecdhePublic := params.PublicKey() + serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHEParams[0] = 3 // named curve + serverECDHEParams[1] = byte(curveID >> 8) + serverECDHEParams[2] = byte(curveID) + serverECDHEParams[3] = byte(len(ecdhePublic)) + copy(serverECDHEParams[4:], ecdhePublic) + + priv, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) + } + + var signatureAlgorithm SignatureScheme + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) + if err != nil { + return nil, err + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return nil, err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) + if err != nil { + return nil, err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) + + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err := priv.Sign(config.rand(), signed, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } + + skx := new(serverKeyExchangeMsg) + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) + copy(skx.key, serverECDHEParams) + k := skx.key[len(serverECDHEParams):] + if ka.version >= VersionTLS12 { + k[0] = byte(signatureAlgorithm >> 8) + k[1] = byte(signatureAlgorithm) + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + + return skx, nil +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + + preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) + if preMasterSecret == nil { + return nil, errClientKeyExchange + } + + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + serverECDHEParams := skx.key[:4+publicLen] + publicKey := serverECDHEParams[4:] + + sig := skx.key[4+publicLen:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return errors.New("tls: server selected unsupported curve") + } + + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return err + } + ka.params = params + + ka.preMasterSecret = params.SharedKey(publicKey) + if ka.preMasterSecret == nil { + return errServerKeyExchange + } + + ourPublicKey := params.PublicKey() + ka.ckx = new(clientKeyExchangeMsg) + ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) + ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) + copy(ka.ckx.ciphertext[1:], ourPublicKey) + + var sigType uint8 + var sigHash crypto.Hash + if ka.version >= VersionTLS12 { + signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + + if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { + return errors.New("tls: certificate used with invalid signature algorithm") + } + sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + if err != nil { + return err + } + } else { + sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) + if err != nil { + return err + } + } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } + return nil +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.ckx == nil { + return nil, nil, errors.New("tls: missing ServerKeyExchange message") + } + + return ka.preMasterSecret, ka.ckx, nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go b/vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go new file mode 100644 index 00000000..da13904a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go @@ -0,0 +1,199 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto/elliptic" + "crypto/hmac" + "errors" + "hash" + "io" + "math/big" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +// This file contains the functions necessary to compute the TLS 1.3 key +// schedule. See RFC 8446, Section 7. + +const ( + resumptionBinderLabel = "res binder" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + exporterLabel = "exp master" + resumptionLabel = "res master" + trafficUpdateLabel = "traffic upd" +) + +// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("tls: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { + if transcript == nil { + transcript = c.hash.New() + } + return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) +} + +// extract implements HKDF-Extract with the cipher suite hash. +func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, c.hash.Size()) + } + return hkdf.Extract(c.hash.New, newSecret, currentSecret) +} + +// nextTrafficSecret generates the next traffic secret, given the current one, +// according to RFC 8446, Section 7.2. +func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { + return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) +} + +// trafficKey generates traffic keys according to RFC 8446, Section 7.3. +func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { + key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) + iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + return +} + +// finishedHash generates the Finished verify_data or PskBinderEntry according +// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey +// selection. +func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { + finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) + verifyData := hmac.New(c.hash.New, finishedKey) + verifyData.Write(transcript.Sum(nil)) + return verifyData.Sum(nil) +} + +// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to +// RFC 8446, Section 7.5. +func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) + return func(label string, context []byte, length int) ([]byte, error) { + secret := c.deriveSecret(expMasterSecret, label, nil) + h := c.hash.New() + h.Write(context) + return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + } +} + +// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// according to RFC 8446, Section 4.2.8.2. +type ecdheParameters interface { + CurveID() CurveID + PublicKey() []byte + SharedKey(peerPublicKey []byte) []byte +} + +func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { + if curveID == X25519 { + privateKey := make([]byte, curve25519.ScalarSize) + if _, err := io.ReadFull(rand, privateKey); err != nil { + return nil, err + } + publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) + if err != nil { + return nil, err + } + return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, errors.New("tls: internal error: unsupported curve") + } + + p := &nistParameters{curveID: curveID} + var err error + p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) + if err != nil { + return nil, err + } + return p, nil +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } +} + +type nistParameters struct { + privateKey []byte + x, y *big.Int // public key + curveID CurveID +} + +func (p *nistParameters) CurveID() CurveID { + return p.curveID +} + +func (p *nistParameters) PublicKey() []byte { + curve, _ := curveForCurveID(p.curveID) + return elliptic.Marshal(curve, p.x, p.y) +} + +func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { + curve, _ := curveForCurveID(p.curveID) + // Unmarshal also checks whether the given point is on the curve. + x, y := elliptic.Unmarshal(curve, peerPublicKey) + if x == nil { + return nil + } + + xShared, _ := curve.ScalarMult(x, y, p.privateKey) + sharedKey := make([]byte, (curve.Params().BitSize+7)/8) + return xShared.FillBytes(sharedKey) +} + +type x25519Parameters struct { + privateKey []byte + publicKey []byte +} + +func (p *x25519Parameters) CurveID() CurveID { + return X25519 +} + +func (p *x25519Parameters) PublicKey() []byte { + return p.publicKey[:] +} + +func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { + sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) + if err != nil { + return nil + } + return sharedKey +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/notboring.go b/vendor/github.com/marten-seemann/qtls-go1-19/notboring.go new file mode 100644 index 00000000..f292e4f0 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/notboring.go @@ -0,0 +1,18 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +func needFIPS() bool { return false } + +func supportedSignatureAlgorithms() []SignatureScheme { + return defaultSupportedSignatureAlgorithms +} + +func fipsMinVersion(c *config) uint16 { panic("fipsMinVersion") } +func fipsMaxVersion(c *config) uint16 { panic("fipsMaxVersion") } +func fipsCurvePreferences(c *config) []CurveID { panic("fipsCurvePreferences") } +func fipsCipherSuites(c *config) []uint16 { panic("fipsCipherSuites") } + +var fipsSupportedSignatureAlgorithms []SignatureScheme diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/prf.go b/vendor/github.com/marten-seemann/qtls-go1-19/prf.go new file mode 100644 index 00000000..9eb0221a --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/prf.go @@ -0,0 +1,283 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, Section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, Section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + copy(result[j:], b) + j += len(b) + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +const ( + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") + +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + switch version { + case VersionTLS10, VersionTLS11: + return prf10, crypto.Hash(0) + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + default: + panic("unknown version") + } +} + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See RFC 5246, Section 8.1. +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { + seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) + return masterSecret +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, Section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) + seed = append(seed, serverRandom...) + seed = append(seed, clientRandom...) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version >= VersionTLS12 { + buffer = []byte{} + } + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// hashForClientCertificate returns the handshake messages so far, pre-hashed if +// necessary, suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { + if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { + panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") + } + + if sigType == signatureEd25519 { + return h.buffer + } + + if h.version >= VersionTLS12 { + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil) + } + + if sigType == signatureECDSA { + return h.server.Sum(nil) + } + + return h.Sum() +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} + +// noExportedKeyingMaterial is used as a value of +// ConnectionState.ekm when renegotiation is enabled and thus +// we wish to fail all key-material export requests. +func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") +} + +// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. +func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { + return func(label string, context []byte, length int) ([]byte, error) { + switch label { + case "client finished", "server finished", "master secret", "key expansion": + // These values are reserved and may not be used. + return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) + } + + seedLen := len(serverRandom) + len(clientRandom) + if context != nil { + seedLen += 2 + len(context) + } + seed := make([]byte, 0, seedLen) + + seed = append(seed, clientRandom...) + seed = append(seed, serverRandom...) + + if context != nil { + if len(context) >= 1<<16 { + return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") + } + seed = append(seed, byte(len(context)>>8), byte(len(context))) + seed = append(seed, context...) + } + + keyMaterial := make([]byte, length) + prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) + return keyMaterial, nil + } +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/ticket.go b/vendor/github.com/marten-seemann/qtls-go1-19/ticket.go new file mode 100644 index 00000000..81e8a52e --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/ticket.go @@ -0,0 +1,274 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/binary" + "errors" + "io" + "time" + + "golang.org/x/crypto/cryptobyte" +) + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + createdAt uint64 + masterSecret []byte // opaque master_secret<1..2^16-1>; + // struct { opaque certificate<1..2^24-1> } Certificate; + certificates [][]byte // Certificate certificate_list<0..2^24-1>; + + // usedOldKey is true if the ticket from which this session came from + // was encrypted with an older key and thus should be refreshed. + usedOldKey bool +} + +func (m *sessionState) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(m.vers) + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.masterSecret) + }) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for _, cert := range m.certificates { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + } + }) + return b.BytesOrPanic() +} + +func (m *sessionState) unmarshal(data []byte) bool { + *m = sessionState{usedOldKey: m.usedOldKey} + s := cryptobyte.String(data) + if ok := s.ReadUint16(&m.vers) && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint16LengthPrefixed(&s, &m.masterSecret) && + len(m.masterSecret) != 0; !ok { + return false + } + var certList cryptobyte.String + if !s.ReadUint24LengthPrefixed(&certList) { + return false + } + for !certList.Empty() { + var cert []byte + if !readUint24LengthPrefixed(&certList, &cert) { + return false + } + m.certificates = append(m.certificates, cert) + } + return s.Empty() +} + +// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first +// version (revision = 0) doesn't carry any of the information needed for 0-RTT +// validation and the nonce is always empty. +// version (revision = 1) carries the max_early_data_size sent in the ticket. +// version (revision = 2) carries the ALPN sent in the ticket. +type sessionStateTLS13 struct { + // uint8 version = 0x0304; + // uint8 revision = 2; + cipherSuite uint16 + createdAt uint64 + resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; + certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; + maxEarlyData uint32 + alpn string + + appData []byte +} + +func (m *sessionStateTLS13) marshal() []byte { + var b cryptobyte.Builder + b.AddUint16(VersionTLS13) + b.AddUint8(2) // revision + b.AddUint16(m.cipherSuite) + addUint64(&b, m.createdAt) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.resumptionSecret) + }) + marshalCertificate(&b, m.certificate) + b.AddUint32(m.maxEarlyData) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpn)) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.appData) + }) + return b.BytesOrPanic() +} + +func (m *sessionStateTLS13) unmarshal(data []byte) bool { + *m = sessionStateTLS13{} + s := cryptobyte.String(data) + var version uint16 + var revision uint8 + var alpn []byte + ret := s.ReadUint16(&version) && + version == VersionTLS13 && + s.ReadUint8(&revision) && + revision == 2 && + s.ReadUint16(&m.cipherSuite) && + readUint64(&s, &m.createdAt) && + readUint8LengthPrefixed(&s, &m.resumptionSecret) && + len(m.resumptionSecret) != 0 && + unmarshalCertificate(&s, &m.certificate) && + s.ReadUint32(&m.maxEarlyData) && + readUint8LengthPrefixed(&s, &alpn) && + readUint16LengthPrefixed(&s, &m.appData) && + s.Empty() + m.alpn = string(alpn) + return ret +} + +func (c *Conn) encryptTicket(state []byte) ([]byte, error) { + if len(c.ticketKeys) == 0 { + return nil, errors.New("tls: internal error: session ticket keys unavailable") + } + + encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + key := c.ticketKeys[0] + copy(keyName, key.keyName[:]) + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { + if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { + return nil, false + } + + keyName := encrypted[:ticketKeyNameLen] + iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] + + keyIndex := -1 + for i, candidateKey := range c.ticketKeys { + if bytes.Equal(keyName, candidateKey.keyName[:]) { + keyIndex = i + break + } + } + if keyIndex == -1 { + return nil, false + } + key := &c.ticketKeys[keyIndex] + + mac := hmac.New(sha256.New, key.hmacKey[:]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(key.aesKey[:]) + if err != nil { + return nil, false + } + plaintext = make([]byte, len(ciphertext)) + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + return plaintext, keyIndex > 0 +} + +func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) { + m := new(newSessionTicketMsgTLS13) + + var certsFromClient [][]byte + for _, cert := range c.peerCertificates { + certsFromClient = append(certsFromClient, cert.Raw) + } + state := sessionStateTLS13{ + cipherSuite: c.cipherSuite, + createdAt: uint64(c.config.time().Unix()), + resumptionSecret: c.resumptionSecret, + certificate: Certificate{ + Certificate: certsFromClient, + OCSPStaple: c.ocspResponse, + SignedCertificateTimestamps: c.scts, + }, + appData: appData, + alpn: c.clientProtocol, + } + if c.extraConfig != nil { + state.maxEarlyData = c.extraConfig.MaxEarlyData + } + var err error + m.label, err = c.encryptTicket(state.marshal()) + if err != nil { + return nil, err + } + m.lifetime = uint32(maxSessionTicketLifetime / time.Second) + + // ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1 + // The value is not stored anywhere; we never need to check the ticket age + // because 0-RTT is not supported. + ageAdd := make([]byte, 4) + _, err = c.config.rand().Read(ageAdd) + if err != nil { + return nil, err + } + m.ageAdd = binary.LittleEndian.Uint32(ageAdd) + + // ticket_nonce, which must be unique per connection, is always left at + // zero because we only ever send one ticket per connection. + + if c.extraConfig != nil { + m.maxEarlyData = c.extraConfig.MaxEarlyData + } + return m, nil +} + +// GetSessionTicket generates a new session ticket. +// It should only be called after the handshake completes. +// It can only be used for servers, and only if the alternative record layer is set. +// The ticket may be nil if config.SessionTicketsDisabled is set, +// or if the client isn't able to receive session tickets. +func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { + if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.") + } + if c.config.SessionTicketsDisabled { + return nil, nil + } + + m, err := c.getSessionTicketMsg(appData) + if err != nil { + return nil, err + } + return m.marshal(), nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/tls.go b/vendor/github.com/marten-seemann/qtls-go1-19/tls.go new file mode 100644 index 00000000..42207c23 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/tls.go @@ -0,0 +1,362 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// package qtls partially implements TLS 1.2, as specified in RFC 5246, +// and TLS 1.3, as specified in RFC 8446. +package qtls + +// BUG(agl): The crypto/tls package only implements some countermeasures +// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 +// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "net" + "os" + "strings" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + } + c.handshakeFn = c.serverHandshake + return c +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerName or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + c := &Conn{ + conn: conn, + config: fromConfig(config), + extraConfig: extraConfig, + isClient: true, + } + c.handshakeFn = c.clientHandshake + return c +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type listener struct { + net.Listener + config *Config + extraConfig *ExtraConfig +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection is of type *Conn. +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return Server(c, l.config, l.extraConfig), nil +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener { + l := new(listener) + l.Listener = inner + l.config = config + l.extraConfig = extraConfig + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must include +// at least one certificate or else set GetCertificate. +func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) { + if config == nil || len(config.Certificates) == 0 && + config.GetCertificate == nil && config.GetConfigForClient == nil { + return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config, extraConfig), nil +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +// +// DialWithDialer uses context.Background internally; to specify the context, +// use Dialer.DialContext with NetDialer set to the desired dialer. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return dial(context.Background(), dialer, network, addr, config, extraConfig) +} + +func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + if netDialer.Timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) + defer cancel() + } + + if !netDialer.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) + defer cancel() + } + + rawConn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = defaultConfig() + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := config.Clone() + c.ServerName = hostname + config = c + } + + conn := Client(rawConn, config, extraConfig) + if err := conn.HandshakeContext(ctx); err != nil { + rawConn.Close() + return nil, err + } + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig) +} + +// Dialer dials TLS connections given a configuration and a Dialer for the +// underlying connection. +type Dialer struct { + // NetDialer is the optional dialer to use for the TLS connections' + // underlying TCP connections. + // A nil NetDialer is equivalent to the net.Dialer zero value. + NetDialer *net.Dialer + + // Config is the TLS configuration to use for new connections. + // A nil configuration is equivalent to the zero + // configuration; see the documentation of Config for the + // defaults. + Config *Config + + ExtraConfig *ExtraConfig +} + +// Dial connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The returned Conn, if any, will always be of type *Conn. +// +// Dial uses context.Background internally; to specify the context, +// use DialContext. +func (d *Dialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func (d *Dialer) netDialer() *net.Dialer { + if d.NetDialer != nil { + return d.NetDialer + } + return new(net.Dialer) +} + +// DialContext connects to the given network address and initiates a TLS +// handshake, returning the resulting TLS connection. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// The returned Conn, if any, will always be of type *Conn. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig) + if err != nil { + // Don't return c (a typed nil) in an interface. + return nil, err + } + return c, nil +} + +// LoadX509KeyPair reads and parses a public/private key pair from a pair +// of files. The files must contain PEM encoded data. The certificate file +// may contain intermediate certificates following the leaf certificate to +// form a certificate chain. On successful return, Certificate.Leaf will +// be nil because the parsed form of the certificate is not retained. +func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := os.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := os.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. On successful return, Certificate.Leaf will be nil because +// the parsed form of the certificate is not retained. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + fail := func(err error) (Certificate, error) { return Certificate{}, err } + + var cert Certificate + var skippedBlockTypes []string + for { + var certDERBlock *pem.Block + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } else { + skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) + } + } + + if len(cert.Certificate) == 0 { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in certificate input")) + } + if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { + return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) + } + return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + + skippedBlockTypes = skippedBlockTypes[:0] + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + if len(skippedBlockTypes) == 0 { + return fail(errors.New("tls: failed to find any PEM data in key input")) + } + if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { + return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) + } + return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fail(err) + } + + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return fail(err) + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.N.Cmp(priv.N) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + return fail(errors.New("tls: private key does not match public key")) + } + case ed25519.PublicKey: + priv, ok := cert.PrivateKey.(ed25519.PrivateKey) + if !ok { + return fail(errors.New("tls: private key type does not match public key type")) + } + if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { + return fail(errors.New("tls: private key does not match public key")) + } + default: + return fail(errors.New("tls: unknown public key algorithm")) + } + + return cert, nil +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse private key") +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go b/vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go new file mode 100644 index 00000000..55fa01b3 --- /dev/null +++ b/vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go @@ -0,0 +1,96 @@ +package qtls + +import ( + "crypto/tls" + "reflect" + "unsafe" +) + +func init() { + if !structsEqual(&tls.ConnectionState{}, &connectionState{}) { + panic("qtls.ConnectionState doesn't match") + } + if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) { + panic("qtls.ClientSessionState doesn't match") + } + if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) { + panic("qtls.CertificateRequestInfo doesn't match") + } + if !structsEqual(&tls.Config{}, &config{}) { + panic("qtls.Config doesn't match") + } + if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { + panic("qtls.ClientHelloInfo doesn't match") + } +} + +func toConnectionState(c connectionState) ConnectionState { + return *(*ConnectionState)(unsafe.Pointer(&c)) +} + +func toClientSessionState(s *clientSessionState) *ClientSessionState { + return (*ClientSessionState)(unsafe.Pointer(s)) +} + +func fromClientSessionState(s *ClientSessionState) *clientSessionState { + return (*clientSessionState)(unsafe.Pointer(s)) +} + +func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo { + return (*CertificateRequestInfo)(unsafe.Pointer(i)) +} + +func toConfig(c *config) *Config { + return (*Config)(unsafe.Pointer(c)) +} + +func fromConfig(c *Config) *config { + return (*config)(unsafe.Pointer(c)) +} + +func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo { + return (*ClientHelloInfo)(unsafe.Pointer(chi)) +} + +func structsEqual(a, b interface{}) bool { + return compare(reflect.ValueOf(a), reflect.ValueOf(b)) +} + +func compare(a, b reflect.Value) bool { + sa := a.Elem() + sb := b.Elem() + if sa.NumField() != sb.NumField() { + return false + } + for i := 0; i < sa.NumField(); i++ { + fa := sa.Type().Field(i) + fb := sb.Type().Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + if fa.Type.Kind() != fb.Type.Kind() { + return false + } + if fa.Type.Kind() == reflect.Slice { + if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) { + return false + } + continue + } + return false + } + } + return true +} + +func compareStruct(a, b reflect.Type) bool { + if a.NumField() != b.NumField() { + return false + } + for i := 0; i < a.NumField(); i++ { + fa := a.Field(i) + fb := b.Field(i) + if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { + return false + } + } + return true +} diff --git a/vendor/github.com/nxadm/tail/.gitignore b/vendor/github.com/nxadm/tail/.gitignore new file mode 100644 index 00000000..35d9351d --- /dev/null +++ b/vendor/github.com/nxadm/tail/.gitignore @@ -0,0 +1,3 @@ +.idea/ +.test/ +examples/_* \ No newline at end of file diff --git a/vendor/github.com/nxadm/tail/CHANGES.md b/vendor/github.com/nxadm/tail/CHANGES.md new file mode 100644 index 00000000..224e54b4 --- /dev/null +++ b/vendor/github.com/nxadm/tail/CHANGES.md @@ -0,0 +1,56 @@ +# Version v1.4.7-v1.4.8 +* Documentation updates. +* Small linter cleanups. +* Added example in test. + +# Version v1.4.6 + +* Document the usage of Cleanup when re-reading a file (thanks to @lesovsky) for issue #18. +* Add example directories with example and tests for issues. + +# Version v1.4.4-v1.4.5 + +* Fix of checksum problem because of forced tag. No changes to the code. + +# Version v1.4.1 + +* Incorporated PR 162 by by Mohammed902: "Simplify non-Windows build tag". + +# Version v1.4.0 + +* Incorporated PR 9 by mschneider82: "Added seekinfo to Tail". + +# Version v1.3.1 + +* Incorporated PR 7: "Fix deadlock when stopping on non-empty file/buffer", +fixes upstream issue 93. + + +# Version v1.3.0 + +* Incorporated changes of unmerged upstream PR 149 by mezzi: "added line num +to Line struct". + +# Version v1.2.1 + +* Incorporated changes of unmerged upstream PR 128 by jadekler: "Compile-able +code in readme". +* Incorporated changes of unmerged upstream PR 130 by fgeller: "small change +to comment wording". +* Incorporated changes of unmerged upstream PR 133 by sm3142: "removed +spurious newlines from log messages". + +# Version v1.2.0 + +* Incorporated changes of unmerged upstream PR 126 by Code-Hex: "Solved the + problem for never return the last line if it's not followed by a newline". +* Incorporated changes of unmerged upstream PR 131 by StoicPerlman: "Remove +deprecated os.SEEK consts". The changes bumped the minimal supported Go +release to 1.9. + +# Version v1.1.0 + +* migration to go modules. +* release of master branch of the dormant upstream, because it contains +fixes and improvement no present in the tagged release. + diff --git a/vendor/github.com/nxadm/tail/Dockerfile b/vendor/github.com/nxadm/tail/Dockerfile new file mode 100644 index 00000000..d9633891 --- /dev/null +++ b/vendor/github.com/nxadm/tail/Dockerfile @@ -0,0 +1,19 @@ +FROM golang + +RUN mkdir -p $GOPATH/src/github.com/nxadm/tail/ +ADD . $GOPATH/src/github.com/nxadm/tail/ + +# expecting to fetch dependencies successfully. +RUN go get -v github.com/nxadm/tail + +# expecting to run the test successfully. +RUN go test -v github.com/nxadm/tail + +# expecting to install successfully +RUN go install -v github.com/nxadm/tail +RUN go install -v github.com/nxadm/tail/cmd/gotail + +RUN $GOPATH/bin/gotail -h || true + +ENV PATH $GOPATH/bin:$PATH +CMD ["gotail"] diff --git a/vendor/github.com/nxadm/tail/LICENSE b/vendor/github.com/nxadm/tail/LICENSE new file mode 100644 index 00000000..818d802a --- /dev/null +++ b/vendor/github.com/nxadm/tail/LICENSE @@ -0,0 +1,21 @@ +# The MIT License (MIT) + +# © Copyright 2015 Hewlett Packard Enterprise Development LP +Copyright (c) 2014 ActiveState + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/nxadm/tail/README.md b/vendor/github.com/nxadm/tail/README.md new file mode 100644 index 00000000..f47939c7 --- /dev/null +++ b/vendor/github.com/nxadm/tail/README.md @@ -0,0 +1,44 @@ +![ci](https://github.com/nxadm/tail/workflows/ci/badge.svg)[![Go Reference](https://pkg.go.dev/badge/github.com/nxadm/tail.svg)](https://pkg.go.dev/github.com/nxadm/tail) + +# tail functionality in Go + +nxadm/tail provides a Go library that emulates the features of the BSD `tail` +program. The library comes with full support for truncation/move detection as +it is designed to work with log rotation tools. The library works on all +operating systems supported by Go, including POSIX systems like Linux and +*BSD, and MS Windows. Go 1.9 is the oldest compiler release supported. + +A simple example: + +```Go +// Create a tail +t, err := tail.TailFile( + "/var/log/nginx.log", tail.Config{Follow: true, ReOpen: true}) +if err != nil { + panic(err) +} + +// Print the text of each received line +for line := range t.Lines { + fmt.Println(line.Text) +} +``` + +See [API documentation](https://pkg.go.dev/github.com/nxadm/tail). + +## Installing + + go get github.com/nxadm/tail/... + +## History + +This project is an active, drop-in replacement for the +[abandoned](https://en.wikipedia.org/wiki/HPE_Helion) Go tail library at +[hpcloud](https://github.com/hpcloud/tail). Next to +[addressing open issues/PRs of the original project](https://github.com/nxadm/tail/issues/6), +nxadm/tail continues the development by keeping up to date with the Go toolchain +(e.g. go modules) and dependencies, completing the documentation, adding features +and fixing bugs. + +## Examples +Examples, e.g. used to debug an issue, are kept in the [examples directory](/examples). \ No newline at end of file diff --git a/vendor/github.com/nxadm/tail/ratelimiter/Licence b/vendor/github.com/nxadm/tail/ratelimiter/Licence new file mode 100644 index 00000000..434aab19 --- /dev/null +++ b/vendor/github.com/nxadm/tail/ratelimiter/Licence @@ -0,0 +1,7 @@ +Copyright (C) 2013 99designs + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go b/vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go new file mode 100644 index 00000000..358b69e7 --- /dev/null +++ b/vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go @@ -0,0 +1,97 @@ +// Package ratelimiter implements the Leaky Bucket ratelimiting algorithm with memcached and in-memory backends. +package ratelimiter + +import ( + "time" +) + +type LeakyBucket struct { + Size uint16 + Fill float64 + LeakInterval time.Duration // time.Duration for 1 unit of size to leak + Lastupdate time.Time + Now func() time.Time +} + +func NewLeakyBucket(size uint16, leakInterval time.Duration) *LeakyBucket { + bucket := LeakyBucket{ + Size: size, + Fill: 0, + LeakInterval: leakInterval, + Now: time.Now, + Lastupdate: time.Now(), + } + + return &bucket +} + +func (b *LeakyBucket) updateFill() { + now := b.Now() + if b.Fill > 0 { + elapsed := now.Sub(b.Lastupdate) + + b.Fill -= float64(elapsed) / float64(b.LeakInterval) + if b.Fill < 0 { + b.Fill = 0 + } + } + b.Lastupdate = now +} + +func (b *LeakyBucket) Pour(amount uint16) bool { + b.updateFill() + + var newfill float64 = b.Fill + float64(amount) + + if newfill > float64(b.Size) { + return false + } + + b.Fill = newfill + + return true +} + +// The time at which this bucket will be completely drained +func (b *LeakyBucket) DrainedAt() time.Time { + return b.Lastupdate.Add(time.Duration(b.Fill * float64(b.LeakInterval))) +} + +// The duration until this bucket is completely drained +func (b *LeakyBucket) TimeToDrain() time.Duration { + return b.DrainedAt().Sub(b.Now()) +} + +func (b *LeakyBucket) TimeSinceLastUpdate() time.Duration { + return b.Now().Sub(b.Lastupdate) +} + +type LeakyBucketSer struct { + Size uint16 + Fill float64 + LeakInterval time.Duration // time.Duration for 1 unit of size to leak + Lastupdate time.Time +} + +func (b *LeakyBucket) Serialise() *LeakyBucketSer { + bucket := LeakyBucketSer{ + Size: b.Size, + Fill: b.Fill, + LeakInterval: b.LeakInterval, + Lastupdate: b.Lastupdate, + } + + return &bucket +} + +func (b *LeakyBucketSer) DeSerialise() *LeakyBucket { + bucket := LeakyBucket{ + Size: b.Size, + Fill: b.Fill, + LeakInterval: b.LeakInterval, + Lastupdate: b.Lastupdate, + Now: time.Now, + } + + return &bucket +} diff --git a/vendor/github.com/nxadm/tail/ratelimiter/memory.go b/vendor/github.com/nxadm/tail/ratelimiter/memory.go new file mode 100644 index 00000000..bf3c2131 --- /dev/null +++ b/vendor/github.com/nxadm/tail/ratelimiter/memory.go @@ -0,0 +1,60 @@ +package ratelimiter + +import ( + "errors" + "time" +) + +const ( + GC_SIZE int = 100 + GC_PERIOD time.Duration = 60 * time.Second +) + +type Memory struct { + store map[string]LeakyBucket + lastGCCollected time.Time +} + +func NewMemory() *Memory { + m := new(Memory) + m.store = make(map[string]LeakyBucket) + m.lastGCCollected = time.Now() + return m +} + +func (m *Memory) GetBucketFor(key string) (*LeakyBucket, error) { + + bucket, ok := m.store[key] + if !ok { + return nil, errors.New("miss") + } + + return &bucket, nil +} + +func (m *Memory) SetBucketFor(key string, bucket LeakyBucket) error { + + if len(m.store) > GC_SIZE { + m.GarbageCollect() + } + + m.store[key] = bucket + + return nil +} + +func (m *Memory) GarbageCollect() { + now := time.Now() + + // rate limit GC to once per minute + if now.Unix() >= m.lastGCCollected.Add(GC_PERIOD).Unix() { + for key, bucket := range m.store { + // if the bucket is drained, then GC + if bucket.DrainedAt().Unix() < now.Unix() { + delete(m.store, key) + } + } + + m.lastGCCollected = now + } +} diff --git a/vendor/github.com/nxadm/tail/ratelimiter/storage.go b/vendor/github.com/nxadm/tail/ratelimiter/storage.go new file mode 100644 index 00000000..89b2fe88 --- /dev/null +++ b/vendor/github.com/nxadm/tail/ratelimiter/storage.go @@ -0,0 +1,6 @@ +package ratelimiter + +type Storage interface { + GetBucketFor(string) (*LeakyBucket, error) + SetBucketFor(string, LeakyBucket) error +} diff --git a/vendor/github.com/nxadm/tail/tail.go b/vendor/github.com/nxadm/tail/tail.go new file mode 100644 index 00000000..37ea4411 --- /dev/null +++ b/vendor/github.com/nxadm/tail/tail.go @@ -0,0 +1,455 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// Copyright (c) 2015 HPE Software Inc. All rights reserved. +// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. + +//nxadm/tail provides a Go library that emulates the features of the BSD `tail` +//program. The library comes with full support for truncation/move detection as +//it is designed to work with log rotation tools. The library works on all +//operating systems supported by Go, including POSIX systems like Linux and +//*BSD, and MS Windows. Go 1.9 is the oldest compiler release supported. +package tail + +import ( + "bufio" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "strings" + "sync" + "time" + + "github.com/nxadm/tail/ratelimiter" + "github.com/nxadm/tail/util" + "github.com/nxadm/tail/watch" + "gopkg.in/tomb.v1" +) + +var ( + // ErrStop is returned when the tail of a file has been marked to be stopped. + ErrStop = errors.New("tail should now stop") +) + +type Line struct { + Text string // The contents of the file + Num int // The line number + SeekInfo SeekInfo // SeekInfo + Time time.Time // Present time + Err error // Error from tail +} + +// Deprecated: this function is no longer used internally and it has little of no +// use in the API. As such, it will be removed from the API in a future major +// release. +// +// NewLine returns a * pointer to a Line struct. +func NewLine(text string, lineNum int) *Line { + return &Line{text, lineNum, SeekInfo{}, time.Now(), nil} +} + +// SeekInfo represents arguments to io.Seek. See: https://golang.org/pkg/io/#SectionReader.Seek +type SeekInfo struct { + Offset int64 + Whence int +} + +type logger interface { + Fatal(v ...interface{}) + Fatalf(format string, v ...interface{}) + Fatalln(v ...interface{}) + Panic(v ...interface{}) + Panicf(format string, v ...interface{}) + Panicln(v ...interface{}) + Print(v ...interface{}) + Printf(format string, v ...interface{}) + Println(v ...interface{}) +} + +// Config is used to specify how a file must be tailed. +type Config struct { + // File-specifc + Location *SeekInfo // Tail from this location. If nil, start at the beginning of the file + ReOpen bool // Reopen recreated files (tail -F) + MustExist bool // Fail early if the file does not exist + Poll bool // Poll for file changes instead of using the default inotify + Pipe bool // The file is a named pipe (mkfifo) + + // Generic IO + Follow bool // Continue looking for new lines (tail -f) + MaxLineSize int // If non-zero, split longer lines into multiple lines + + // Optionally, use a ratelimiter (e.g. created by the ratelimiter/NewLeakyBucket function) + RateLimiter *ratelimiter.LeakyBucket + + // Optionally use a Logger. When nil, the Logger is set to tail.DefaultLogger. + // To disable logging, set it to tail.DiscardingLogger + Logger logger +} + +type Tail struct { + Filename string // The filename + Lines chan *Line // A consumable channel of *Line + Config // Tail.Configuration + + file *os.File + reader *bufio.Reader + lineNum int + + watcher watch.FileWatcher + changes *watch.FileChanges + + tomb.Tomb // provides: Done, Kill, Dying + + lk sync.Mutex +} + +var ( + // DefaultLogger logs to os.Stderr and it is used when Config.Logger == nil + DefaultLogger = log.New(os.Stderr, "", log.LstdFlags) + // DiscardingLogger can be used to disable logging output + DiscardingLogger = log.New(ioutil.Discard, "", 0) +) + +// TailFile begins tailing the file. And returns a pointer to a Tail struct +// and an error. An output stream is made available via the Tail.Lines +// channel (e.g. to be looped and printed). To handle errors during tailing, +// after finishing reading from the Lines channel, invoke the `Wait` or `Err` +// method on the returned *Tail. +func TailFile(filename string, config Config) (*Tail, error) { + if config.ReOpen && !config.Follow { + util.Fatal("cannot set ReOpen without Follow.") + } + + t := &Tail{ + Filename: filename, + Lines: make(chan *Line), + Config: config, + } + + // when Logger was not specified in config, use default logger + if t.Logger == nil { + t.Logger = DefaultLogger + } + + if t.Poll { + t.watcher = watch.NewPollingFileWatcher(filename) + } else { + t.watcher = watch.NewInotifyFileWatcher(filename) + } + + if t.MustExist { + var err error + t.file, err = OpenFile(t.Filename) + if err != nil { + return nil, err + } + } + + go t.tailFileSync() + + return t, nil +} + +// Tell returns the file's current position, like stdio's ftell() and an error. +// Beware that this value may not be completely accurate because one line from +// the chan(tail.Lines) may have been read already. +func (tail *Tail) Tell() (offset int64, err error) { + if tail.file == nil { + return + } + offset, err = tail.file.Seek(0, io.SeekCurrent) + if err != nil { + return + } + + tail.lk.Lock() + defer tail.lk.Unlock() + if tail.reader == nil { + return + } + + offset -= int64(tail.reader.Buffered()) + return +} + +// Stop stops the tailing activity. +func (tail *Tail) Stop() error { + tail.Kill(nil) + return tail.Wait() +} + +// StopAtEOF stops tailing as soon as the end of the file is reached. The function +// returns an error, +func (tail *Tail) StopAtEOF() error { + tail.Kill(errStopAtEOF) + return tail.Wait() +} + +var errStopAtEOF = errors.New("tail: stop at eof") + +func (tail *Tail) close() { + close(tail.Lines) + tail.closeFile() +} + +func (tail *Tail) closeFile() { + if tail.file != nil { + tail.file.Close() + tail.file = nil + } +} + +func (tail *Tail) reopen() error { + tail.closeFile() + tail.lineNum = 0 + for { + var err error + tail.file, err = OpenFile(tail.Filename) + if err != nil { + if os.IsNotExist(err) { + tail.Logger.Printf("Waiting for %s to appear...", tail.Filename) + if err := tail.watcher.BlockUntilExists(&tail.Tomb); err != nil { + if err == tomb.ErrDying { + return err + } + return fmt.Errorf("Failed to detect creation of %s: %s", tail.Filename, err) + } + continue + } + return fmt.Errorf("Unable to open file %s: %s", tail.Filename, err) + } + break + } + return nil +} + +func (tail *Tail) readLine() (string, error) { + tail.lk.Lock() + line, err := tail.reader.ReadString('\n') + tail.lk.Unlock() + if err != nil { + // Note ReadString "returns the data read before the error" in + // case of an error, including EOF, so we return it as is. The + // caller is expected to process it if err is EOF. + return line, err + } + + line = strings.TrimRight(line, "\n") + + return line, err +} + +func (tail *Tail) tailFileSync() { + defer tail.Done() + defer tail.close() + + if !tail.MustExist { + // deferred first open. + err := tail.reopen() + if err != nil { + if err != tomb.ErrDying { + tail.Kill(err) + } + return + } + } + + // Seek to requested location on first open of the file. + if tail.Location != nil { + _, err := tail.file.Seek(tail.Location.Offset, tail.Location.Whence) + if err != nil { + tail.Killf("Seek error on %s: %s", tail.Filename, err) + return + } + } + + tail.openReader() + + // Read line by line. + for { + // do not seek in named pipes + if !tail.Pipe { + // grab the position in case we need to back up in the event of a half-line + if _, err := tail.Tell(); err != nil { + tail.Kill(err) + return + } + } + + line, err := tail.readLine() + + // Process `line` even if err is EOF. + if err == nil { + cooloff := !tail.sendLine(line) + if cooloff { + // Wait a second before seeking till the end of + // file when rate limit is reached. + msg := ("Too much log activity; waiting a second before resuming tailing") + offset, _ := tail.Tell() + tail.Lines <- &Line{msg, tail.lineNum, SeekInfo{Offset: offset}, time.Now(), errors.New(msg)} + select { + case <-time.After(time.Second): + case <-tail.Dying(): + return + } + if err := tail.seekEnd(); err != nil { + tail.Kill(err) + return + } + } + } else if err == io.EOF { + if !tail.Follow { + if line != "" { + tail.sendLine(line) + } + return + } + + if tail.Follow && line != "" { + tail.sendLine(line) + if err := tail.seekEnd(); err != nil { + tail.Kill(err) + return + } + } + + // When EOF is reached, wait for more data to become + // available. Wait strategy is based on the `tail.watcher` + // implementation (inotify or polling). + err := tail.waitForChanges() + if err != nil { + if err != ErrStop { + tail.Kill(err) + } + return + } + } else { + // non-EOF error + tail.Killf("Error reading %s: %s", tail.Filename, err) + return + } + + select { + case <-tail.Dying(): + if tail.Err() == errStopAtEOF { + continue + } + return + default: + } + } +} + +// waitForChanges waits until the file has been appended, deleted, +// moved or truncated. When moved or deleted - the file will be +// reopened if ReOpen is true. Truncated files are always reopened. +func (tail *Tail) waitForChanges() error { + if tail.changes == nil { + pos, err := tail.file.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + tail.changes, err = tail.watcher.ChangeEvents(&tail.Tomb, pos) + if err != nil { + return err + } + } + + select { + case <-tail.changes.Modified: + return nil + case <-tail.changes.Deleted: + tail.changes = nil + if tail.ReOpen { + // XXX: we must not log from a library. + tail.Logger.Printf("Re-opening moved/deleted file %s ...", tail.Filename) + if err := tail.reopen(); err != nil { + return err + } + tail.Logger.Printf("Successfully reopened %s", tail.Filename) + tail.openReader() + return nil + } + tail.Logger.Printf("Stopping tail as file no longer exists: %s", tail.Filename) + return ErrStop + case <-tail.changes.Truncated: + // Always reopen truncated files (Follow is true) + tail.Logger.Printf("Re-opening truncated file %s ...", tail.Filename) + if err := tail.reopen(); err != nil { + return err + } + tail.Logger.Printf("Successfully reopened truncated %s", tail.Filename) + tail.openReader() + return nil + case <-tail.Dying(): + return ErrStop + } +} + +func (tail *Tail) openReader() { + tail.lk.Lock() + if tail.MaxLineSize > 0 { + // add 2 to account for newline characters + tail.reader = bufio.NewReaderSize(tail.file, tail.MaxLineSize+2) + } else { + tail.reader = bufio.NewReader(tail.file) + } + tail.lk.Unlock() +} + +func (tail *Tail) seekEnd() error { + return tail.seekTo(SeekInfo{Offset: 0, Whence: io.SeekEnd}) +} + +func (tail *Tail) seekTo(pos SeekInfo) error { + _, err := tail.file.Seek(pos.Offset, pos.Whence) + if err != nil { + return fmt.Errorf("Seek error on %s: %s", tail.Filename, err) + } + // Reset the read buffer whenever the file is re-seek'ed + tail.reader.Reset(tail.file) + return nil +} + +// sendLine sends the line(s) to Lines channel, splitting longer lines +// if necessary. Return false if rate limit is reached. +func (tail *Tail) sendLine(line string) bool { + now := time.Now() + lines := []string{line} + + // Split longer lines + if tail.MaxLineSize > 0 && len(line) > tail.MaxLineSize { + lines = util.PartitionString(line, tail.MaxLineSize) + } + + for _, line := range lines { + tail.lineNum++ + offset, _ := tail.Tell() + select { + case tail.Lines <- &Line{line, tail.lineNum, SeekInfo{Offset: offset}, now, nil}: + case <-tail.Dying(): + return true + } + } + + if tail.Config.RateLimiter != nil { + ok := tail.Config.RateLimiter.Pour(uint16(len(lines))) + if !ok { + tail.Logger.Printf("Leaky bucket full (%v); entering 1s cooloff period.", + tail.Filename) + return false + } + } + + return true +} + +// Cleanup removes inotify watches added by the tail package. This function is +// meant to be invoked from a process's exit handler. Linux kernel may not +// automatically remove inotify watches after the process exits. +// If you plan to re-read a file, don't call Cleanup in between. +func (tail *Tail) Cleanup() { + watch.Cleanup(tail.Filename) +} diff --git a/vendor/github.com/nxadm/tail/tail_posix.go b/vendor/github.com/nxadm/tail/tail_posix.go new file mode 100644 index 00000000..23e071de --- /dev/null +++ b/vendor/github.com/nxadm/tail/tail_posix.go @@ -0,0 +1,17 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// +build !windows + +package tail + +import ( + "os" +) + +// Deprecated: this function is only useful internally and, as such, +// it will be removed from the API in a future major release. +// +// OpenFile proxies a os.Open call for a file so it can be correctly tailed +// on POSIX and non-POSIX OSes like MS Windows. +func OpenFile(name string) (file *os.File, err error) { + return os.Open(name) +} diff --git a/vendor/github.com/nxadm/tail/tail_windows.go b/vendor/github.com/nxadm/tail/tail_windows.go new file mode 100644 index 00000000..da0d2f39 --- /dev/null +++ b/vendor/github.com/nxadm/tail/tail_windows.go @@ -0,0 +1,19 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// +build windows + +package tail + +import ( + "os" + + "github.com/nxadm/tail/winfile" +) + +// Deprecated: this function is only useful internally and, as such, +// it will be removed from the API in a future major release. +// +// OpenFile proxies a os.Open call for a file so it can be correctly tailed +// on POSIX and non-POSIX OSes like MS Windows. +func OpenFile(name string) (file *os.File, err error) { + return winfile.OpenFile(name, os.O_RDONLY, 0) +} diff --git a/vendor/github.com/nxadm/tail/util/util.go b/vendor/github.com/nxadm/tail/util/util.go new file mode 100644 index 00000000..b64caa21 --- /dev/null +++ b/vendor/github.com/nxadm/tail/util/util.go @@ -0,0 +1,49 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// Copyright (c) 2015 HPE Software Inc. All rights reserved. +// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. + +package util + +import ( + "fmt" + "log" + "os" + "runtime/debug" +) + +type Logger struct { + *log.Logger +} + +var LOGGER = &Logger{log.New(os.Stderr, "", log.LstdFlags)} + +// fatal is like panic except it displays only the current goroutine's stack. +func Fatal(format string, v ...interface{}) { + // https://github.com/nxadm/log/blob/master/log.go#L45 + LOGGER.Output(2, fmt.Sprintf("FATAL -- "+format, v...)+"\n"+string(debug.Stack())) + os.Exit(1) +} + +// partitionString partitions the string into chunks of given size, +// with the last chunk of variable size. +func PartitionString(s string, chunkSize int) []string { + if chunkSize <= 0 { + panic("invalid chunkSize") + } + length := len(s) + chunks := 1 + length/chunkSize + start := 0 + end := chunkSize + parts := make([]string, 0, chunks) + for { + if end > length { + end = length + } + parts = append(parts, s[start:end]) + if end == length { + break + } + start, end = end, end+chunkSize + } + return parts +} diff --git a/vendor/github.com/nxadm/tail/watch/filechanges.go b/vendor/github.com/nxadm/tail/watch/filechanges.go new file mode 100644 index 00000000..5b65f42a --- /dev/null +++ b/vendor/github.com/nxadm/tail/watch/filechanges.go @@ -0,0 +1,37 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +package watch + +type FileChanges struct { + Modified chan bool // Channel to get notified of modifications + Truncated chan bool // Channel to get notified of truncations + Deleted chan bool // Channel to get notified of deletions/renames +} + +func NewFileChanges() *FileChanges { + return &FileChanges{ + make(chan bool, 1), make(chan bool, 1), make(chan bool, 1)} +} + +func (fc *FileChanges) NotifyModified() { + sendOnlyIfEmpty(fc.Modified) +} + +func (fc *FileChanges) NotifyTruncated() { + sendOnlyIfEmpty(fc.Truncated) +} + +func (fc *FileChanges) NotifyDeleted() { + sendOnlyIfEmpty(fc.Deleted) +} + +// sendOnlyIfEmpty sends on a bool channel only if the channel has no +// backlog to be read by other goroutines. This concurrency pattern +// can be used to notify other goroutines if and only if they are +// looking for it (i.e., subsequent notifications can be compressed +// into one). +func sendOnlyIfEmpty(ch chan bool) { + select { + case ch <- true: + default: + } +} diff --git a/vendor/github.com/nxadm/tail/watch/inotify.go b/vendor/github.com/nxadm/tail/watch/inotify.go new file mode 100644 index 00000000..cbd11ad8 --- /dev/null +++ b/vendor/github.com/nxadm/tail/watch/inotify.go @@ -0,0 +1,136 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// Copyright (c) 2015 HPE Software Inc. All rights reserved. +// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. + +package watch + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/nxadm/tail/util" + + "github.com/fsnotify/fsnotify" + "gopkg.in/tomb.v1" +) + +// InotifyFileWatcher uses inotify to monitor file changes. +type InotifyFileWatcher struct { + Filename string + Size int64 +} + +func NewInotifyFileWatcher(filename string) *InotifyFileWatcher { + fw := &InotifyFileWatcher{filepath.Clean(filename), 0} + return fw +} + +func (fw *InotifyFileWatcher) BlockUntilExists(t *tomb.Tomb) error { + err := WatchCreate(fw.Filename) + if err != nil { + return err + } + defer RemoveWatchCreate(fw.Filename) + + // Do a real check now as the file might have been created before + // calling `WatchFlags` above. + if _, err = os.Stat(fw.Filename); !os.IsNotExist(err) { + // file exists, or stat returned an error. + return err + } + + events := Events(fw.Filename) + + for { + select { + case evt, ok := <-events: + if !ok { + return fmt.Errorf("inotify watcher has been closed") + } + evtName, err := filepath.Abs(evt.Name) + if err != nil { + return err + } + fwFilename, err := filepath.Abs(fw.Filename) + if err != nil { + return err + } + if evtName == fwFilename { + return nil + } + case <-t.Dying(): + return tomb.ErrDying + } + } + panic("unreachable") +} + +func (fw *InotifyFileWatcher) ChangeEvents(t *tomb.Tomb, pos int64) (*FileChanges, error) { + err := Watch(fw.Filename) + if err != nil { + return nil, err + } + + changes := NewFileChanges() + fw.Size = pos + + go func() { + + events := Events(fw.Filename) + + for { + prevSize := fw.Size + + var evt fsnotify.Event + var ok bool + + select { + case evt, ok = <-events: + if !ok { + RemoveWatch(fw.Filename) + return + } + case <-t.Dying(): + RemoveWatch(fw.Filename) + return + } + + switch { + case evt.Op&fsnotify.Remove == fsnotify.Remove: + fallthrough + + case evt.Op&fsnotify.Rename == fsnotify.Rename: + RemoveWatch(fw.Filename) + changes.NotifyDeleted() + return + + //With an open fd, unlink(fd) - inotify returns IN_ATTRIB (==fsnotify.Chmod) + case evt.Op&fsnotify.Chmod == fsnotify.Chmod: + fallthrough + + case evt.Op&fsnotify.Write == fsnotify.Write: + fi, err := os.Stat(fw.Filename) + if err != nil { + if os.IsNotExist(err) { + RemoveWatch(fw.Filename) + changes.NotifyDeleted() + return + } + // XXX: report this error back to the user + util.Fatal("Failed to stat file %v: %v", fw.Filename, err) + } + fw.Size = fi.Size() + + if prevSize > 0 && prevSize > fw.Size { + changes.NotifyTruncated() + } else { + changes.NotifyModified() + } + prevSize = fw.Size + } + } + }() + + return changes, nil +} diff --git a/vendor/github.com/nxadm/tail/watch/inotify_tracker.go b/vendor/github.com/nxadm/tail/watch/inotify_tracker.go new file mode 100644 index 00000000..cb9572a0 --- /dev/null +++ b/vendor/github.com/nxadm/tail/watch/inotify_tracker.go @@ -0,0 +1,249 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// Copyright (c) 2015 HPE Software Inc. All rights reserved. +// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. + +package watch + +import ( + "log" + "os" + "path/filepath" + "sync" + "syscall" + + "github.com/nxadm/tail/util" + + "github.com/fsnotify/fsnotify" +) + +type InotifyTracker struct { + mux sync.Mutex + watcher *fsnotify.Watcher + chans map[string]chan fsnotify.Event + done map[string]chan bool + watchNums map[string]int + watch chan *watchInfo + remove chan *watchInfo + error chan error +} + +type watchInfo struct { + op fsnotify.Op + fname string +} + +func (this *watchInfo) isCreate() bool { + return this.op == fsnotify.Create +} + +var ( + // globally shared InotifyTracker; ensures only one fsnotify.Watcher is used + shared *InotifyTracker + + // these are used to ensure the shared InotifyTracker is run exactly once + once = sync.Once{} + goRun = func() { + shared = &InotifyTracker{ + mux: sync.Mutex{}, + chans: make(map[string]chan fsnotify.Event), + done: make(map[string]chan bool), + watchNums: make(map[string]int), + watch: make(chan *watchInfo), + remove: make(chan *watchInfo), + error: make(chan error), + } + go shared.run() + } + + logger = log.New(os.Stderr, "", log.LstdFlags) +) + +// Watch signals the run goroutine to begin watching the input filename +func Watch(fname string) error { + return watch(&watchInfo{ + fname: fname, + }) +} + +// Watch create signals the run goroutine to begin watching the input filename +// if call the WatchCreate function, don't call the Cleanup, call the RemoveWatchCreate +func WatchCreate(fname string) error { + return watch(&watchInfo{ + op: fsnotify.Create, + fname: fname, + }) +} + +func watch(winfo *watchInfo) error { + // start running the shared InotifyTracker if not already running + once.Do(goRun) + + winfo.fname = filepath.Clean(winfo.fname) + shared.watch <- winfo + return <-shared.error +} + +// RemoveWatch signals the run goroutine to remove the watch for the input filename +func RemoveWatch(fname string) error { + return remove(&watchInfo{ + fname: fname, + }) +} + +// RemoveWatch create signals the run goroutine to remove the watch for the input filename +func RemoveWatchCreate(fname string) error { + return remove(&watchInfo{ + op: fsnotify.Create, + fname: fname, + }) +} + +func remove(winfo *watchInfo) error { + // start running the shared InotifyTracker if not already running + once.Do(goRun) + + winfo.fname = filepath.Clean(winfo.fname) + shared.mux.Lock() + done := shared.done[winfo.fname] + if done != nil { + delete(shared.done, winfo.fname) + close(done) + } + shared.mux.Unlock() + + shared.remove <- winfo + return <-shared.error +} + +// Events returns a channel to which FileEvents corresponding to the input filename +// will be sent. This channel will be closed when removeWatch is called on this +// filename. +func Events(fname string) <-chan fsnotify.Event { + shared.mux.Lock() + defer shared.mux.Unlock() + + return shared.chans[fname] +} + +// Cleanup removes the watch for the input filename if necessary. +func Cleanup(fname string) error { + return RemoveWatch(fname) +} + +// watchFlags calls fsnotify.WatchFlags for the input filename and flags, creating +// a new Watcher if the previous Watcher was closed. +func (shared *InotifyTracker) addWatch(winfo *watchInfo) error { + shared.mux.Lock() + defer shared.mux.Unlock() + + if shared.chans[winfo.fname] == nil { + shared.chans[winfo.fname] = make(chan fsnotify.Event) + } + if shared.done[winfo.fname] == nil { + shared.done[winfo.fname] = make(chan bool) + } + + fname := winfo.fname + if winfo.isCreate() { + // Watch for new files to be created in the parent directory. + fname = filepath.Dir(fname) + } + + var err error + // already in inotify watch + if shared.watchNums[fname] == 0 { + err = shared.watcher.Add(fname) + } + if err == nil { + shared.watchNums[fname]++ + } + return err +} + +// removeWatch calls fsnotify.RemoveWatch for the input filename and closes the +// corresponding events channel. +func (shared *InotifyTracker) removeWatch(winfo *watchInfo) error { + shared.mux.Lock() + + ch := shared.chans[winfo.fname] + if ch != nil { + delete(shared.chans, winfo.fname) + close(ch) + } + + fname := winfo.fname + if winfo.isCreate() { + // Watch for new files to be created in the parent directory. + fname = filepath.Dir(fname) + } + shared.watchNums[fname]-- + watchNum := shared.watchNums[fname] + if watchNum == 0 { + delete(shared.watchNums, fname) + } + shared.mux.Unlock() + + var err error + // If we were the last ones to watch this file, unsubscribe from inotify. + // This needs to happen after releasing the lock because fsnotify waits + // synchronously for the kernel to acknowledge the removal of the watch + // for this file, which causes us to deadlock if we still held the lock. + if watchNum == 0 { + err = shared.watcher.Remove(fname) + } + + return err +} + +// sendEvent sends the input event to the appropriate Tail. +func (shared *InotifyTracker) sendEvent(event fsnotify.Event) { + name := filepath.Clean(event.Name) + + shared.mux.Lock() + ch := shared.chans[name] + done := shared.done[name] + shared.mux.Unlock() + + if ch != nil && done != nil { + select { + case ch <- event: + case <-done: + } + } +} + +// run starts the goroutine in which the shared struct reads events from its +// Watcher's Event channel and sends the events to the appropriate Tail. +func (shared *InotifyTracker) run() { + watcher, err := fsnotify.NewWatcher() + if err != nil { + util.Fatal("failed to create Watcher") + } + shared.watcher = watcher + + for { + select { + case winfo := <-shared.watch: + shared.error <- shared.addWatch(winfo) + + case winfo := <-shared.remove: + shared.error <- shared.removeWatch(winfo) + + case event, open := <-shared.watcher.Events: + if !open { + return + } + shared.sendEvent(event) + + case err, open := <-shared.watcher.Errors: + if !open { + return + } else if err != nil { + sysErr, ok := err.(*os.SyscallError) + if !ok || sysErr.Err != syscall.EINTR { + logger.Printf("Error in Watcher Error channel: %s", err) + } + } + } + } +} diff --git a/vendor/github.com/nxadm/tail/watch/polling.go b/vendor/github.com/nxadm/tail/watch/polling.go new file mode 100644 index 00000000..74e10aa4 --- /dev/null +++ b/vendor/github.com/nxadm/tail/watch/polling.go @@ -0,0 +1,119 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// Copyright (c) 2015 HPE Software Inc. All rights reserved. +// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. + +package watch + +import ( + "os" + "runtime" + "time" + + "github.com/nxadm/tail/util" + "gopkg.in/tomb.v1" +) + +// PollingFileWatcher polls the file for changes. +type PollingFileWatcher struct { + Filename string + Size int64 +} + +func NewPollingFileWatcher(filename string) *PollingFileWatcher { + fw := &PollingFileWatcher{filename, 0} + return fw +} + +var POLL_DURATION time.Duration + +func (fw *PollingFileWatcher) BlockUntilExists(t *tomb.Tomb) error { + for { + if _, err := os.Stat(fw.Filename); err == nil { + return nil + } else if !os.IsNotExist(err) { + return err + } + select { + case <-time.After(POLL_DURATION): + continue + case <-t.Dying(): + return tomb.ErrDying + } + } + panic("unreachable") +} + +func (fw *PollingFileWatcher) ChangeEvents(t *tomb.Tomb, pos int64) (*FileChanges, error) { + origFi, err := os.Stat(fw.Filename) + if err != nil { + return nil, err + } + + changes := NewFileChanges() + var prevModTime time.Time + + // XXX: use tomb.Tomb to cleanly manage these goroutines. replace + // the fatal (below) with tomb's Kill. + + fw.Size = pos + + go func() { + prevSize := fw.Size + for { + select { + case <-t.Dying(): + return + default: + } + + time.Sleep(POLL_DURATION) + fi, err := os.Stat(fw.Filename) + if err != nil { + // Windows cannot delete a file if a handle is still open (tail keeps one open) + // so it gives access denied to anything trying to read it until all handles are released. + if os.IsNotExist(err) || (runtime.GOOS == "windows" && os.IsPermission(err)) { + // File does not exist (has been deleted). + changes.NotifyDeleted() + return + } + + // XXX: report this error back to the user + util.Fatal("Failed to stat file %v: %v", fw.Filename, err) + } + + // File got moved/renamed? + if !os.SameFile(origFi, fi) { + changes.NotifyDeleted() + return + } + + // File got truncated? + fw.Size = fi.Size() + if prevSize > 0 && prevSize > fw.Size { + changes.NotifyTruncated() + prevSize = fw.Size + continue + } + // File got bigger? + if prevSize > 0 && prevSize < fw.Size { + changes.NotifyModified() + prevSize = fw.Size + continue + } + prevSize = fw.Size + + // File was appended to (changed)? + modTime := fi.ModTime() + if modTime != prevModTime { + prevModTime = modTime + changes.NotifyModified() + } + } + }() + + return changes, nil +} + +func init() { + POLL_DURATION = 250 * time.Millisecond +} diff --git a/vendor/github.com/nxadm/tail/watch/watch.go b/vendor/github.com/nxadm/tail/watch/watch.go new file mode 100644 index 00000000..2b511280 --- /dev/null +++ b/vendor/github.com/nxadm/tail/watch/watch.go @@ -0,0 +1,21 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// Copyright (c) 2015 HPE Software Inc. All rights reserved. +// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. + +package watch + +import "gopkg.in/tomb.v1" + +// FileWatcher monitors file-level events. +type FileWatcher interface { + // BlockUntilExists blocks until the file comes into existence. + BlockUntilExists(*tomb.Tomb) error + + // ChangeEvents reports on changes to a file, be it modification, + // deletion, renames or truncations. Returned FileChanges group of + // channels will be closed, thus become unusable, after a deletion + // or truncation event. + // In order to properly report truncations, ChangeEvents requires + // the caller to pass their current offset in the file. + ChangeEvents(*tomb.Tomb, int64) (*FileChanges, error) +} diff --git a/vendor/github.com/nxadm/tail/winfile/winfile.go b/vendor/github.com/nxadm/tail/winfile/winfile.go new file mode 100644 index 00000000..4562ac7c --- /dev/null +++ b/vendor/github.com/nxadm/tail/winfile/winfile.go @@ -0,0 +1,93 @@ +// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail +// +build windows + +package winfile + +import ( + "os" + "syscall" + "unsafe" +) + +// issue also described here +//https://codereview.appspot.com/8203043/ + +// https://github.com/jnwhiteh/golang/blob/master/src/pkg/syscall/syscall_windows.go#L218 +func Open(path string, mode int, perm uint32) (fd syscall.Handle, err error) { + if len(path) == 0 { + return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND + } + pathp, err := syscall.UTF16PtrFromString(path) + if err != nil { + return syscall.InvalidHandle, err + } + var access uint32 + switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) { + case syscall.O_RDONLY: + access = syscall.GENERIC_READ + case syscall.O_WRONLY: + access = syscall.GENERIC_WRITE + case syscall.O_RDWR: + access = syscall.GENERIC_READ | syscall.GENERIC_WRITE + } + if mode&syscall.O_CREAT != 0 { + access |= syscall.GENERIC_WRITE + } + if mode&syscall.O_APPEND != 0 { + access &^= syscall.GENERIC_WRITE + access |= syscall.FILE_APPEND_DATA + } + sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE) + var sa *syscall.SecurityAttributes + if mode&syscall.O_CLOEXEC == 0 { + sa = makeInheritSa() + } + var createmode uint32 + switch { + case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL): + createmode = syscall.CREATE_NEW + case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC): + createmode = syscall.CREATE_ALWAYS + case mode&syscall.O_CREAT == syscall.O_CREAT: + createmode = syscall.OPEN_ALWAYS + case mode&syscall.O_TRUNC == syscall.O_TRUNC: + createmode = syscall.TRUNCATE_EXISTING + default: + createmode = syscall.OPEN_EXISTING + } + h, e := syscall.CreateFile(pathp, access, sharemode, sa, createmode, syscall.FILE_ATTRIBUTE_NORMAL, 0) + return h, e +} + +// https://github.com/jnwhiteh/golang/blob/master/src/pkg/syscall/syscall_windows.go#L211 +func makeInheritSa() *syscall.SecurityAttributes { + var sa syscall.SecurityAttributes + sa.Length = uint32(unsafe.Sizeof(sa)) + sa.InheritHandle = 1 + return &sa +} + +// https://github.com/jnwhiteh/golang/blob/master/src/pkg/os/file_windows.go#L133 +func OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) { + r, e := Open(name, flag|syscall.O_CLOEXEC, syscallMode(perm)) + if e != nil { + return nil, e + } + return os.NewFile(uintptr(r), name), nil +} + +// https://github.com/jnwhiteh/golang/blob/master/src/pkg/os/file_posix.go#L61 +func syscallMode(i os.FileMode) (o uint32) { + o |= uint32(i.Perm()) + if i&os.ModeSetuid != 0 { + o |= syscall.S_ISUID + } + if i&os.ModeSetgid != 0 { + o |= syscall.S_ISGID + } + if i&os.ModeSticky != 0 { + o |= syscall.S_ISVTX + } + // No mapping for Go's ModeTemporary (plan9 only). + return +} diff --git a/vendor/github.com/onsi/ginkgo/LICENSE b/vendor/github.com/onsi/ginkgo/LICENSE new file mode 100644 index 00000000..9415ee72 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2013-2014 Onsi Fakhouri + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/onsi/ginkgo/config/config.go b/vendor/github.com/onsi/ginkgo/config/config.go new file mode 100644 index 00000000..5f3f4396 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/config/config.go @@ -0,0 +1,232 @@ +/* +Ginkgo accepts a number of configuration options. + +These are documented [here](http://onsi.github.io/ginkgo/#the-ginkgo-cli) + +You can also learn more via + + ginkgo help + +or (I kid you not): + + go test -asdf +*/ +package config + +import ( + "flag" + "time" + + "fmt" +) + +const VERSION = "1.16.4" + +type GinkgoConfigType struct { + RandomSeed int64 + RandomizeAllSpecs bool + RegexScansFilePath bool + FocusStrings []string + SkipStrings []string + SkipMeasurements bool + FailOnPending bool + FailFast bool + FlakeAttempts int + EmitSpecProgress bool + DryRun bool + DebugParallel bool + + ParallelNode int + ParallelTotal int + SyncHost string + StreamHost string +} + +var GinkgoConfig = GinkgoConfigType{} + +type DefaultReporterConfigType struct { + NoColor bool + SlowSpecThreshold float64 + NoisyPendings bool + NoisySkippings bool + Succinct bool + Verbose bool + FullTrace bool + ReportPassed bool + ReportFile string +} + +var DefaultReporterConfig = DefaultReporterConfigType{} + +func processPrefix(prefix string) string { + if prefix != "" { + prefix += "." + } + return prefix +} + +type flagFunc func(string) + +func (f flagFunc) String() string { return "" } +func (f flagFunc) Set(s string) error { f(s); return nil } + +func Flags(flagSet *flag.FlagSet, prefix string, includeParallelFlags bool) { + prefix = processPrefix(prefix) + flagSet.Int64Var(&(GinkgoConfig.RandomSeed), prefix+"seed", time.Now().Unix(), "The seed used to randomize the spec suite.") + flagSet.BoolVar(&(GinkgoConfig.RandomizeAllSpecs), prefix+"randomizeAllSpecs", false, "If set, ginkgo will randomize all specs together. By default, ginkgo only randomizes the top level Describe, Context and When groups.") + flagSet.BoolVar(&(GinkgoConfig.SkipMeasurements), prefix+"skipMeasurements", false, "If set, ginkgo will skip any measurement specs.") + flagSet.BoolVar(&(GinkgoConfig.FailOnPending), prefix+"failOnPending", false, "If set, ginkgo will mark the test suite as failed if any specs are pending.") + flagSet.BoolVar(&(GinkgoConfig.FailFast), prefix+"failFast", false, "If set, ginkgo will stop running a test suite after a failure occurs.") + + flagSet.BoolVar(&(GinkgoConfig.DryRun), prefix+"dryRun", false, "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v.") + + flagSet.Var(flagFunc(flagFocus), prefix+"focus", "If set, ginkgo will only run specs that match this regular expression. Can be specified multiple times, values are ORed.") + flagSet.Var(flagFunc(flagSkip), prefix+"skip", "If set, ginkgo will only run specs that do not match this regular expression. Can be specified multiple times, values are ORed.") + + flagSet.BoolVar(&(GinkgoConfig.RegexScansFilePath), prefix+"regexScansFilePath", false, "If set, ginkgo regex matching also will look at the file path (code location).") + + flagSet.IntVar(&(GinkgoConfig.FlakeAttempts), prefix+"flakeAttempts", 1, "Make up to this many attempts to run each spec. Please note that if any of the attempts succeed, the suite will not be failed. But any failures will still be recorded.") + + flagSet.BoolVar(&(GinkgoConfig.EmitSpecProgress), prefix+"progress", false, "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter.") + + flagSet.BoolVar(&(GinkgoConfig.DebugParallel), prefix+"debug", false, "If set, ginkgo will emit node output to files when running in parallel.") + + if includeParallelFlags { + flagSet.IntVar(&(GinkgoConfig.ParallelNode), prefix+"parallel.node", 1, "This worker node's (one-indexed) node number. For running specs in parallel.") + flagSet.IntVar(&(GinkgoConfig.ParallelTotal), prefix+"parallel.total", 1, "The total number of worker nodes. For running specs in parallel.") + flagSet.StringVar(&(GinkgoConfig.SyncHost), prefix+"parallel.synchost", "", "The address for the server that will synchronize the running nodes.") + flagSet.StringVar(&(GinkgoConfig.StreamHost), prefix+"parallel.streamhost", "", "The address for the server that the running nodes should stream data to.") + } + + flagSet.BoolVar(&(DefaultReporterConfig.NoColor), prefix+"noColor", false, "If set, suppress color output in default reporter.") + flagSet.Float64Var(&(DefaultReporterConfig.SlowSpecThreshold), prefix+"slowSpecThreshold", 5.0, "(in seconds) Specs that take longer to run than this threshold are flagged as slow by the default reporter.") + flagSet.BoolVar(&(DefaultReporterConfig.NoisyPendings), prefix+"noisyPendings", true, "If set, default reporter will shout about pending tests.") + flagSet.BoolVar(&(DefaultReporterConfig.NoisySkippings), prefix+"noisySkippings", true, "If set, default reporter will shout about skipping tests.") + flagSet.BoolVar(&(DefaultReporterConfig.Verbose), prefix+"v", false, "If set, default reporter print out all specs as they begin.") + flagSet.BoolVar(&(DefaultReporterConfig.Succinct), prefix+"succinct", false, "If set, default reporter prints out a very succinct report") + flagSet.BoolVar(&(DefaultReporterConfig.FullTrace), prefix+"trace", false, "If set, default reporter prints out the full stack trace when a failure occurs") + flagSet.BoolVar(&(DefaultReporterConfig.ReportPassed), prefix+"reportPassed", false, "If set, default reporter prints out captured output of passed tests.") + flagSet.StringVar(&(DefaultReporterConfig.ReportFile), prefix+"reportFile", "", "Override the default reporter output file path.") + +} + +func BuildFlagArgs(prefix string, ginkgo GinkgoConfigType, reporter DefaultReporterConfigType) []string { + prefix = processPrefix(prefix) + result := make([]string, 0) + + if ginkgo.RandomSeed > 0 { + result = append(result, fmt.Sprintf("--%sseed=%d", prefix, ginkgo.RandomSeed)) + } + + if ginkgo.RandomizeAllSpecs { + result = append(result, fmt.Sprintf("--%srandomizeAllSpecs", prefix)) + } + + if ginkgo.SkipMeasurements { + result = append(result, fmt.Sprintf("--%sskipMeasurements", prefix)) + } + + if ginkgo.FailOnPending { + result = append(result, fmt.Sprintf("--%sfailOnPending", prefix)) + } + + if ginkgo.FailFast { + result = append(result, fmt.Sprintf("--%sfailFast", prefix)) + } + + if ginkgo.DryRun { + result = append(result, fmt.Sprintf("--%sdryRun", prefix)) + } + + for _, s := range ginkgo.FocusStrings { + result = append(result, fmt.Sprintf("--%sfocus=%s", prefix, s)) + } + + for _, s := range ginkgo.SkipStrings { + result = append(result, fmt.Sprintf("--%sskip=%s", prefix, s)) + } + + if ginkgo.FlakeAttempts > 1 { + result = append(result, fmt.Sprintf("--%sflakeAttempts=%d", prefix, ginkgo.FlakeAttempts)) + } + + if ginkgo.EmitSpecProgress { + result = append(result, fmt.Sprintf("--%sprogress", prefix)) + } + + if ginkgo.DebugParallel { + result = append(result, fmt.Sprintf("--%sdebug", prefix)) + } + + if ginkgo.ParallelNode != 0 { + result = append(result, fmt.Sprintf("--%sparallel.node=%d", prefix, ginkgo.ParallelNode)) + } + + if ginkgo.ParallelTotal != 0 { + result = append(result, fmt.Sprintf("--%sparallel.total=%d", prefix, ginkgo.ParallelTotal)) + } + + if ginkgo.StreamHost != "" { + result = append(result, fmt.Sprintf("--%sparallel.streamhost=%s", prefix, ginkgo.StreamHost)) + } + + if ginkgo.SyncHost != "" { + result = append(result, fmt.Sprintf("--%sparallel.synchost=%s", prefix, ginkgo.SyncHost)) + } + + if ginkgo.RegexScansFilePath { + result = append(result, fmt.Sprintf("--%sregexScansFilePath", prefix)) + } + + if reporter.NoColor { + result = append(result, fmt.Sprintf("--%snoColor", prefix)) + } + + if reporter.SlowSpecThreshold > 0 { + result = append(result, fmt.Sprintf("--%sslowSpecThreshold=%.5f", prefix, reporter.SlowSpecThreshold)) + } + + if !reporter.NoisyPendings { + result = append(result, fmt.Sprintf("--%snoisyPendings=false", prefix)) + } + + if !reporter.NoisySkippings { + result = append(result, fmt.Sprintf("--%snoisySkippings=false", prefix)) + } + + if reporter.Verbose { + result = append(result, fmt.Sprintf("--%sv", prefix)) + } + + if reporter.Succinct { + result = append(result, fmt.Sprintf("--%ssuccinct", prefix)) + } + + if reporter.FullTrace { + result = append(result, fmt.Sprintf("--%strace", prefix)) + } + + if reporter.ReportPassed { + result = append(result, fmt.Sprintf("--%sreportPassed", prefix)) + } + + if reporter.ReportFile != "" { + result = append(result, fmt.Sprintf("--%sreportFile=%s", prefix, reporter.ReportFile)) + } + + return result +} + +// flagFocus implements the -focus flag. +func flagFocus(arg string) { + if arg != "" { + GinkgoConfig.FocusStrings = append(GinkgoConfig.FocusStrings, arg) + } +} + +// flagSkip implements the -skip flag. +func flagSkip(arg string) { + if arg != "" { + GinkgoConfig.SkipStrings = append(GinkgoConfig.SkipStrings, arg) + } +} diff --git a/vendor/github.com/onsi/ginkgo/formatter/formatter.go b/vendor/github.com/onsi/ginkgo/formatter/formatter.go new file mode 100644 index 00000000..30d7cbe1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/formatter/formatter.go @@ -0,0 +1,190 @@ +package formatter + +import ( + "fmt" + "regexp" + "strings" +) + +const COLS = 80 + +type ColorMode uint8 + +const ( + ColorModeNone ColorMode = iota + ColorModeTerminal + ColorModePassthrough +) + +var SingletonFormatter = New(ColorModeTerminal) + +func F(format string, args ...interface{}) string { + return SingletonFormatter.F(format, args...) +} + +func Fi(indentation uint, format string, args ...interface{}) string { + return SingletonFormatter.Fi(indentation, format, args...) +} + +func Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string { + return SingletonFormatter.Fiw(indentation, maxWidth, format, args...) +} + +type Formatter struct { + ColorMode ColorMode + colors map[string]string + styleRe *regexp.Regexp + preserveColorStylingTags bool +} + +func NewWithNoColorBool(noColor bool) Formatter { + if noColor { + return New(ColorModeNone) + } + return New(ColorModeTerminal) +} + +func New(colorMode ColorMode) Formatter { + f := Formatter{ + ColorMode: colorMode, + colors: map[string]string{ + "/": "\x1b[0m", + "bold": "\x1b[1m", + "underline": "\x1b[4m", + + "red": "\x1b[38;5;9m", + "orange": "\x1b[38;5;214m", + "coral": "\x1b[38;5;204m", + "magenta": "\x1b[38;5;13m", + "green": "\x1b[38;5;10m", + "dark-green": "\x1b[38;5;28m", + "yellow": "\x1b[38;5;11m", + "light-yellow": "\x1b[38;5;228m", + "cyan": "\x1b[38;5;14m", + "gray": "\x1b[38;5;243m", + "light-gray": "\x1b[38;5;246m", + "blue": "\x1b[38;5;12m", + }, + } + colors := []string{} + for color := range f.colors { + colors = append(colors, color) + } + f.styleRe = regexp.MustCompile("{{(" + strings.Join(colors, "|") + ")}}") + return f +} + +func (f Formatter) F(format string, args ...interface{}) string { + return f.Fi(0, format, args...) +} + +func (f Formatter) Fi(indentation uint, format string, args ...interface{}) string { + return f.Fiw(indentation, 0, format, args...) +} + +func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string { + out := fmt.Sprintf(f.style(format), args...) + + if indentation == 0 && maxWidth == 0 { + return out + } + + lines := strings.Split(out, "\n") + + if maxWidth != 0 { + outLines := []string{} + + maxWidth = maxWidth - indentation*2 + for _, line := range lines { + if f.length(line) <= maxWidth { + outLines = append(outLines, line) + continue + } + outWords := []string{} + length := uint(0) + words := strings.Split(line, " ") + for _, word := range words { + wordLength := f.length(word) + if length+wordLength <= maxWidth { + length += wordLength + outWords = append(outWords, word) + continue + } + outLines = append(outLines, strings.Join(outWords, " ")) + outWords = []string{word} + length = wordLength + } + if len(outWords) > 0 { + outLines = append(outLines, strings.Join(outWords, " ")) + } + } + + lines = outLines + } + + if indentation == 0 { + return strings.Join(lines, "\n") + } + + padding := strings.Repeat(" ", int(indentation)) + for i := range lines { + if lines[i] != "" { + lines[i] = padding + lines[i] + } + } + + return strings.Join(lines, "\n") +} + +func (f Formatter) length(styled string) uint { + n := uint(0) + inStyle := false + for _, b := range styled { + if inStyle { + if b == 'm' { + inStyle = false + } + continue + } + if b == '\x1b' { + inStyle = true + continue + } + n += 1 + } + return n +} + +func (f Formatter) CycleJoin(elements []string, joiner string, cycle []string) string { + if len(elements) == 0 { + return "" + } + n := len(cycle) + out := "" + for i, text := range elements { + out += cycle[i%n] + text + if i < len(elements)-1 { + out += joiner + } + } + out += "{{/}}" + return f.style(out) +} + +func (f Formatter) style(s string) string { + switch f.ColorMode { + case ColorModeNone: + return f.styleRe.ReplaceAllString(s, "") + case ColorModePassthrough: + return s + case ColorModeTerminal: + return f.styleRe.ReplaceAllStringFunc(s, func(match string) string { + if out, ok := f.colors[strings.Trim(match, "{}")]; ok { + return out + } + return match + }) + } + + return "" +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go new file mode 100644 index 00000000..6f5af391 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go @@ -0,0 +1,200 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "text/template" + + "go/build" + + sprig "github.com/go-task/slim-sprig" + "github.com/onsi/ginkgo/ginkgo/nodot" +) + +func BuildBootstrapCommand() *Command { + var ( + agouti, noDot, internal bool + customBootstrapFile string + ) + flagSet := flag.NewFlagSet("bootstrap", flag.ExitOnError) + flagSet.BoolVar(&agouti, "agouti", false, "If set, bootstrap will generate a bootstrap file for writing Agouti tests") + flagSet.BoolVar(&noDot, "nodot", false, "If set, bootstrap will generate a bootstrap file that does not . import ginkgo and gomega") + flagSet.BoolVar(&internal, "internal", false, "If set, generate will generate a test file that uses the regular package name") + flagSet.StringVar(&customBootstrapFile, "template", "", "If specified, generate will use the contents of the file passed as the bootstrap template") + + return &Command{ + Name: "bootstrap", + FlagSet: flagSet, + UsageCommand: "ginkgo bootstrap ", + Usage: []string{ + "Bootstrap a test suite for the current package", + "Accepts the following flags:", + }, + Command: func(args []string, additionalArgs []string) { + generateBootstrap(agouti, noDot, internal, customBootstrapFile) + }, + } +} + +var bootstrapText = `package {{.Package}} + +import ( + "testing" + + {{.GinkgoImport}} + {{.GomegaImport}} +) + +func Test{{.FormattedName}}(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "{{.FormattedName}} Suite") +} +` + +var agoutiBootstrapText = `package {{.Package}} + +import ( + "testing" + + {{.GinkgoImport}} + {{.GomegaImport}} + "github.com/sclevine/agouti" +) + +func Test{{.FormattedName}}(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "{{.FormattedName}} Suite") +} + +var agoutiDriver *agouti.WebDriver + +var _ = BeforeSuite(func() { + // Choose a WebDriver: + + agoutiDriver = agouti.PhantomJS() + // agoutiDriver = agouti.Selenium() + // agoutiDriver = agouti.ChromeDriver() + + Expect(agoutiDriver.Start()).To(Succeed()) +}) + +var _ = AfterSuite(func() { + Expect(agoutiDriver.Stop()).To(Succeed()) +}) +` + +type bootstrapData struct { + Package string + FormattedName string + GinkgoImport string + GomegaImport string +} + +func getPackageAndFormattedName() (string, string, string) { + path, err := os.Getwd() + if err != nil { + complainAndQuit("Could not get current working directory: \n" + err.Error()) + } + + dirName := strings.Replace(filepath.Base(path), "-", "_", -1) + dirName = strings.Replace(dirName, " ", "_", -1) + + pkg, err := build.ImportDir(path, 0) + packageName := pkg.Name + if err != nil { + packageName = dirName + } + + formattedName := prettifyPackageName(filepath.Base(path)) + return packageName, dirName, formattedName +} + +func prettifyPackageName(name string) string { + name = strings.Replace(name, "-", " ", -1) + name = strings.Replace(name, "_", " ", -1) + name = strings.Title(name) + name = strings.Replace(name, " ", "", -1) + return name +} + +func determinePackageName(name string, internal bool) string { + if internal { + return name + } + + return name + "_test" +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func generateBootstrap(agouti, noDot, internal bool, customBootstrapFile string) { + packageName, bootstrapFilePrefix, formattedName := getPackageAndFormattedName() + data := bootstrapData{ + Package: determinePackageName(packageName, internal), + FormattedName: formattedName, + GinkgoImport: `. "github.com/onsi/ginkgo"`, + GomegaImport: `. "github.com/onsi/gomega"`, + } + + if noDot { + data.GinkgoImport = `"github.com/onsi/ginkgo"` + data.GomegaImport = `"github.com/onsi/gomega"` + } + + targetFile := fmt.Sprintf("%s_suite_test.go", bootstrapFilePrefix) + if fileExists(targetFile) { + fmt.Printf("%s already exists.\n\n", targetFile) + os.Exit(1) + } else { + fmt.Printf("Generating ginkgo test suite bootstrap for %s in:\n\t%s\n", packageName, targetFile) + } + + f, err := os.Create(targetFile) + if err != nil { + complainAndQuit("Could not create file: " + err.Error()) + panic(err.Error()) + } + defer f.Close() + + var templateText string + if customBootstrapFile != "" { + tpl, err := ioutil.ReadFile(customBootstrapFile) + if err != nil { + panic(err.Error()) + } + templateText = string(tpl) + } else if agouti { + templateText = agoutiBootstrapText + } else { + templateText = bootstrapText + } + + bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText) + if err != nil { + panic(err.Error()) + } + + buf := &bytes.Buffer{} + bootstrapTemplate.Execute(buf, data) + + if noDot { + contents, err := nodot.ApplyNoDot(buf.Bytes()) + if err != nil { + complainAndQuit("Failed to import nodot declarations: " + err.Error()) + } + fmt.Println("To update the nodot declarations in the future, switch to this directory and run:\n\tginkgo nodot") + buf = bytes.NewBuffer(contents) + } + + buf.WriteTo(f) + + goFmt(targetFile) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/build_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/build_command.go new file mode 100644 index 00000000..2fddef0f --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/build_command.go @@ -0,0 +1,66 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + + "github.com/onsi/ginkgo/ginkgo/interrupthandler" + "github.com/onsi/ginkgo/ginkgo/testrunner" +) + +func BuildBuildCommand() *Command { + commandFlags := NewBuildCommandFlags(flag.NewFlagSet("build", flag.ExitOnError)) + interruptHandler := interrupthandler.NewInterruptHandler() + builder := &SpecBuilder{ + commandFlags: commandFlags, + interruptHandler: interruptHandler, + } + + return &Command{ + Name: "build", + FlagSet: commandFlags.FlagSet, + UsageCommand: "ginkgo build ", + Usage: []string{ + "Build the passed in (or the package in the current directory if left blank).", + "Accepts the following flags:", + }, + Command: builder.BuildSpecs, + } +} + +type SpecBuilder struct { + commandFlags *RunWatchAndBuildCommandFlags + interruptHandler *interrupthandler.InterruptHandler +} + +func (r *SpecBuilder) BuildSpecs(args []string, additionalArgs []string) { + r.commandFlags.computeNodes() + + suites, _ := findSuites(args, r.commandFlags.Recurse, r.commandFlags.SkipPackage, false) + + if len(suites) == 0 { + complainAndQuit("Found no test suites") + } + + passed := true + for _, suite := range suites { + runner := testrunner.New(suite, 1, false, 0, r.commandFlags.GoOpts, nil) + fmt.Printf("Compiling %s...\n", suite.PackageName) + + path, _ := filepath.Abs(filepath.Join(suite.Path, fmt.Sprintf("%s.test", suite.PackageName))) + err := runner.CompileTo(path) + if err != nil { + fmt.Println(err.Error()) + passed = false + } else { + fmt.Printf(" compiled %s.test\n", suite.PackageName) + } + } + + if passed { + os.Exit(0) + } + os.Exit(1) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go new file mode 100644 index 00000000..02e2b3b3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go @@ -0,0 +1,123 @@ +package convert + +import ( + "fmt" + "go/ast" + "strings" + "unicode" +) + +/* + * Creates a func init() node + */ +func createVarUnderscoreBlock() *ast.ValueSpec { + valueSpec := &ast.ValueSpec{} + object := &ast.Object{Kind: 4, Name: "_", Decl: valueSpec, Data: 0} + ident := &ast.Ident{Name: "_", Obj: object} + valueSpec.Names = append(valueSpec.Names, ident) + return valueSpec +} + +/* + * Creates a Describe("Testing with ginkgo", func() { }) node + */ +func createDescribeBlock() *ast.CallExpr { + blockStatement := &ast.BlockStmt{List: []ast.Stmt{}} + + fieldList := &ast.FieldList{} + funcType := &ast.FuncType{Params: fieldList} + funcLit := &ast.FuncLit{Type: funcType, Body: blockStatement} + basicLit := &ast.BasicLit{Kind: 9, Value: "\"Testing with Ginkgo\""} + describeIdent := &ast.Ident{Name: "Describe"} + return &ast.CallExpr{Fun: describeIdent, Args: []ast.Expr{basicLit, funcLit}} +} + +/* + * Convenience function to return the name of the *testing.T param + * for a Test function that will be rewritten. This is useful because + * we will want to replace the usage of this named *testing.T inside the + * body of the function with a GinktoT. + */ +func namedTestingTArg(node *ast.FuncDecl) string { + return node.Type.Params.List[0].Names[0].Name // *exhale* +} + +/* + * Convenience function to return the block statement node for a Describe statement + */ +func blockStatementFromDescribe(desc *ast.CallExpr) *ast.BlockStmt { + var funcLit *ast.FuncLit + var found = false + + for _, node := range desc.Args { + switch node := node.(type) { + case *ast.FuncLit: + found = true + funcLit = node + break + } + } + + if !found { + panic("Error finding ast.FuncLit inside describe statement. Somebody done goofed.") + } + + return funcLit.Body +} + +/* convenience function for creating an It("TestNameHere") + * with all the body of the test function inside the anonymous + * func passed to It() + */ +func createItStatementForTestFunc(testFunc *ast.FuncDecl) *ast.ExprStmt { + blockStatement := &ast.BlockStmt{List: testFunc.Body.List} + fieldList := &ast.FieldList{} + funcType := &ast.FuncType{Params: fieldList} + funcLit := &ast.FuncLit{Type: funcType, Body: blockStatement} + + testName := rewriteTestName(testFunc.Name.Name) + basicLit := &ast.BasicLit{Kind: 9, Value: fmt.Sprintf("\"%s\"", testName)} + itBlockIdent := &ast.Ident{Name: "It"} + callExpr := &ast.CallExpr{Fun: itBlockIdent, Args: []ast.Expr{basicLit, funcLit}} + return &ast.ExprStmt{X: callExpr} +} + +/* +* rewrite test names to be human readable +* eg: rewrites "TestSomethingAmazing" as "something amazing" + */ +func rewriteTestName(testName string) string { + nameComponents := []string{} + currentString := "" + indexOfTest := strings.Index(testName, "Test") + if indexOfTest != 0 { + return testName + } + + testName = strings.Replace(testName, "Test", "", 1) + first, rest := testName[0], testName[1:] + testName = string(unicode.ToLower(rune(first))) + rest + + for _, rune := range testName { + if unicode.IsUpper(rune) { + nameComponents = append(nameComponents, currentString) + currentString = string(unicode.ToLower(rune)) + } else { + currentString += string(rune) + } + } + + return strings.Join(append(nameComponents, currentString), " ") +} + +func newGinkgoTFromIdent(ident *ast.Ident) *ast.CallExpr { + return &ast.CallExpr{ + Lparen: ident.NamePos + 1, + Rparen: ident.NamePos + 2, + Fun: &ast.Ident{Name: "GinkgoT"}, + } +} + +func newGinkgoTInterface() *ast.Ident { + return &ast.Ident{Name: "GinkgoTInterface"} +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go new file mode 100644 index 00000000..06c6ec94 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go @@ -0,0 +1,90 @@ +package convert + +import ( + "fmt" + "go/ast" +) + +/* + * Given the root node of an AST, returns the node containing the + * import statements for the file. + */ +func importsForRootNode(rootNode *ast.File) (imports *ast.GenDecl, err error) { + for _, declaration := range rootNode.Decls { + decl, ok := declaration.(*ast.GenDecl) + if !ok || len(decl.Specs) == 0 { + continue + } + + _, ok = decl.Specs[0].(*ast.ImportSpec) + if ok { + imports = decl + return + } + } + + err = fmt.Errorf("Could not find imports for root node:\n\t%#v\n", rootNode) + return +} + +/* + * Removes "testing" import, if present + */ +func removeTestingImport(rootNode *ast.File) { + importDecl, err := importsForRootNode(rootNode) + if err != nil { + panic(err.Error()) + } + + var index int + for i, importSpec := range importDecl.Specs { + importSpec := importSpec.(*ast.ImportSpec) + if importSpec.Path.Value == "\"testing\"" { + index = i + break + } + } + + importDecl.Specs = append(importDecl.Specs[:index], importDecl.Specs[index+1:]...) +} + +/* + * Adds import statements for onsi/ginkgo, if missing + */ +func addGinkgoImports(rootNode *ast.File) { + importDecl, err := importsForRootNode(rootNode) + if err != nil { + panic(err.Error()) + } + + if len(importDecl.Specs) == 0 { + // TODO: might need to create a import decl here + panic("unimplemented : expected to find an imports block") + } + + needsGinkgo := true + for _, importSpec := range importDecl.Specs { + importSpec, ok := importSpec.(*ast.ImportSpec) + if !ok { + continue + } + + if importSpec.Path.Value == "\"github.com/onsi/ginkgo\"" { + needsGinkgo = false + } + } + + if needsGinkgo { + importDecl.Specs = append(importDecl.Specs, createImport(".", "\"github.com/onsi/ginkgo\"")) + } +} + +/* + * convenience function to create an import statement + */ +func createImport(name, path string) *ast.ImportSpec { + return &ast.ImportSpec{ + Name: &ast.Ident{Name: name}, + Path: &ast.BasicLit{Kind: 9, Value: path}, + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go new file mode 100644 index 00000000..363e52fe --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go @@ -0,0 +1,128 @@ +package convert + +import ( + "fmt" + "go/build" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "regexp" +) + +/* + * RewritePackage takes a name (eg: my-package/tools), finds its test files using + * Go's build package, and then rewrites them. A ginkgo test suite file will + * also be added for this package, and all of its child packages. + */ +func RewritePackage(packageName string) { + pkg, err := packageWithName(packageName) + if err != nil { + panic(fmt.Sprintf("unexpected error reading package: '%s'\n%s\n", packageName, err.Error())) + } + + for _, filename := range findTestsInPackage(pkg) { + rewriteTestsInFile(filename) + } +} + +/* + * Given a package, findTestsInPackage reads the test files in the directory, + * and then recurses on each child package, returning a slice of all test files + * found in this process. + */ +func findTestsInPackage(pkg *build.Package) (testfiles []string) { + for _, file := range append(pkg.TestGoFiles, pkg.XTestGoFiles...) { + testfile, _ := filepath.Abs(filepath.Join(pkg.Dir, file)) + testfiles = append(testfiles, testfile) + } + + dirFiles, err := ioutil.ReadDir(pkg.Dir) + if err != nil { + panic(fmt.Sprintf("unexpected error reading dir: '%s'\n%s\n", pkg.Dir, err.Error())) + } + + re := regexp.MustCompile(`^[._]`) + + for _, file := range dirFiles { + if !file.IsDir() { + continue + } + + if re.Match([]byte(file.Name())) { + continue + } + + packageName := filepath.Join(pkg.ImportPath, file.Name()) + subPackage, err := packageWithName(packageName) + if err != nil { + panic(fmt.Sprintf("unexpected error reading package: '%s'\n%s\n", packageName, err.Error())) + } + + testfiles = append(testfiles, findTestsInPackage(subPackage)...) + } + + addGinkgoSuiteForPackage(pkg) + goFmtPackage(pkg) + return +} + +/* + * Shells out to `ginkgo bootstrap` to create a test suite file + */ +func addGinkgoSuiteForPackage(pkg *build.Package) { + originalDir, err := os.Getwd() + if err != nil { + panic(err) + } + + suite_test_file := filepath.Join(pkg.Dir, pkg.Name+"_suite_test.go") + + _, err = os.Stat(suite_test_file) + if err == nil { + return // test file already exists, this should be a no-op + } + + err = os.Chdir(pkg.Dir) + if err != nil { + panic(err) + } + + output, err := exec.Command("ginkgo", "bootstrap").Output() + + if err != nil { + panic(fmt.Sprintf("error running 'ginkgo bootstrap'.\nstdout: %s\n%s\n", output, err.Error())) + } + + err = os.Chdir(originalDir) + if err != nil { + panic(err) + } +} + +/* + * Shells out to `go fmt` to format the package + */ +func goFmtPackage(pkg *build.Package) { + path, _ := filepath.Abs(pkg.ImportPath) + output, err := exec.Command("go", "fmt", path).CombinedOutput() + + if err != nil { + fmt.Printf("Warning: Error running 'go fmt %s'.\nstdout: %s\n%s\n", path, output, err.Error()) + } +} + +/* + * Attempts to return a package with its test files already read. + * The ImportMode arg to build.Import lets you specify if you want go to read the + * buildable go files inside the package, but it fails if the package has no go files + */ +func packageWithName(name string) (pkg *build.Package, err error) { + pkg, err = build.Default.Import(name, ".", build.ImportMode(0)) + if err == nil { + return + } + + pkg, err = build.Default.Import(name, ".", build.ImportMode(1)) + return +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go new file mode 100644 index 00000000..b33595c9 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go @@ -0,0 +1,56 @@ +package convert + +import ( + "go/ast" + "regexp" +) + +/* + * Given a root node, walks its top level statements and returns + * points to function nodes to rewrite as It statements. + * These functions, according to Go testing convention, must be named + * TestWithCamelCasedName and receive a single *testing.T argument. + */ +func findTestFuncs(rootNode *ast.File) (testsToRewrite []*ast.FuncDecl) { + testNameRegexp := regexp.MustCompile("^Test[0-9A-Z].+") + + ast.Inspect(rootNode, func(node ast.Node) bool { + if node == nil { + return false + } + + switch node := node.(type) { + case *ast.FuncDecl: + matches := testNameRegexp.MatchString(node.Name.Name) + + if matches && receivesTestingT(node) { + testsToRewrite = append(testsToRewrite, node) + } + } + + return true + }) + + return +} + +/* + * convenience function that looks at args to a function and determines if its + * params include an argument of type *testing.T + */ +func receivesTestingT(node *ast.FuncDecl) bool { + if len(node.Type.Params.List) != 1 { + return false + } + + base, ok := node.Type.Params.List[0].Type.(*ast.StarExpr) + if !ok { + return false + } + + intermediate := base.X.(*ast.SelectorExpr) + isTestingPackage := intermediate.X.(*ast.Ident).Name == "testing" + isTestingT := intermediate.Sel.Name == "T" + + return isTestingPackage && isTestingT +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go new file mode 100644 index 00000000..60c73504 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go @@ -0,0 +1,162 @@ +package convert + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "io/ioutil" + "os" +) + +/* + * Given a file path, rewrites any tests in the Ginkgo format. + * First, we parse the AST, and update the imports declaration. + * Then, we walk the first child elements in the file, returning tests to rewrite. + * A top level init func is declared, with a single Describe func inside. + * Then the test functions to rewrite are inserted as It statements inside the Describe. + * Finally we walk the rest of the file, replacing other usages of *testing.T + * Once that is complete, we write the AST back out again to its file. + */ +func rewriteTestsInFile(pathToFile string) { + fileSet := token.NewFileSet() + rootNode, err := parser.ParseFile(fileSet, pathToFile, nil, parser.ParseComments) + if err != nil { + panic(fmt.Sprintf("Error parsing test file '%s':\n%s\n", pathToFile, err.Error())) + } + + addGinkgoImports(rootNode) + removeTestingImport(rootNode) + + varUnderscoreBlock := createVarUnderscoreBlock() + describeBlock := createDescribeBlock() + varUnderscoreBlock.Values = []ast.Expr{describeBlock} + + for _, testFunc := range findTestFuncs(rootNode) { + rewriteTestFuncAsItStatement(testFunc, rootNode, describeBlock) + } + + underscoreDecl := &ast.GenDecl{ + Tok: 85, // gah, magick numbers are needed to make this work + TokPos: 14, // this tricks Go into writing "var _ = Describe" + Specs: []ast.Spec{varUnderscoreBlock}, + } + + imports := rootNode.Decls[0] + tail := rootNode.Decls[1:] + rootNode.Decls = append(append([]ast.Decl{imports}, underscoreDecl), tail...) + rewriteOtherFuncsToUseGinkgoT(rootNode.Decls) + walkNodesInRootNodeReplacingTestingT(rootNode) + + var buffer bytes.Buffer + if err = format.Node(&buffer, fileSet, rootNode); err != nil { + panic(fmt.Sprintf("Error formatting ast node after rewriting tests.\n%s\n", err.Error())) + } + + fileInfo, err := os.Stat(pathToFile) + + if err != nil { + panic(fmt.Sprintf("Error stat'ing file: %s\n", pathToFile)) + } + + err = ioutil.WriteFile(pathToFile, buffer.Bytes(), fileInfo.Mode()) +} + +/* + * Given a test func named TestDoesSomethingNeat, rewrites it as + * It("does something neat", func() { __test_body_here__ }) and adds it + * to the Describe's list of statements + */ +func rewriteTestFuncAsItStatement(testFunc *ast.FuncDecl, rootNode *ast.File, describe *ast.CallExpr) { + var funcIndex int = -1 + for index, child := range rootNode.Decls { + if child == testFunc { + funcIndex = index + break + } + } + + if funcIndex < 0 { + panic(fmt.Sprintf("Assert failed: Error finding index for test node %s\n", testFunc.Name.Name)) + } + + var block *ast.BlockStmt = blockStatementFromDescribe(describe) + block.List = append(block.List, createItStatementForTestFunc(testFunc)) + replaceTestingTsWithGinkgoT(block, namedTestingTArg(testFunc)) + + // remove the old test func from the root node's declarations + rootNode.Decls = append(rootNode.Decls[:funcIndex], rootNode.Decls[funcIndex+1:]...) +} + +/* + * walks nodes inside of a test func's statements and replaces the usage of + * it's named *testing.T param with GinkgoT's + */ +func replaceTestingTsWithGinkgoT(statementsBlock *ast.BlockStmt, testingT string) { + ast.Inspect(statementsBlock, func(node ast.Node) bool { + if node == nil { + return false + } + + keyValueExpr, ok := node.(*ast.KeyValueExpr) + if ok { + replaceNamedTestingTsInKeyValueExpression(keyValueExpr, testingT) + return true + } + + funcLiteral, ok := node.(*ast.FuncLit) + if ok { + replaceTypeDeclTestingTsInFuncLiteral(funcLiteral) + return true + } + + callExpr, ok := node.(*ast.CallExpr) + if !ok { + return true + } + replaceTestingTsInArgsLists(callExpr, testingT) + + funCall, ok := callExpr.Fun.(*ast.SelectorExpr) + if ok { + replaceTestingTsMethodCalls(funCall, testingT) + } + + return true + }) +} + +/* + * rewrite t.Fail() or any other *testing.T method by replacing with T().Fail() + * This function receives a selector expression (eg: t.Fail()) and + * the name of the *testing.T param from the function declaration. Rewrites the + * selector expression in place if the target was a *testing.T + */ +func replaceTestingTsMethodCalls(selectorExpr *ast.SelectorExpr, testingT string) { + ident, ok := selectorExpr.X.(*ast.Ident) + if !ok { + return + } + + if ident.Name == testingT { + selectorExpr.X = newGinkgoTFromIdent(ident) + } +} + +/* + * replaces usages of a named *testing.T param inside of a call expression + * with a new GinkgoT object + */ +func replaceTestingTsInArgsLists(callExpr *ast.CallExpr, testingT string) { + for index, arg := range callExpr.Args { + ident, ok := arg.(*ast.Ident) + if !ok { + continue + } + + if ident.Name == testingT { + callExpr.Args[index] = newGinkgoTFromIdent(ident) + } + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go new file mode 100644 index 00000000..418cdc4e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go @@ -0,0 +1,130 @@ +package convert + +import ( + "go/ast" +) + +/* + * Rewrites any other top level funcs that receive a *testing.T param + */ +func rewriteOtherFuncsToUseGinkgoT(declarations []ast.Decl) { + for _, decl := range declarations { + decl, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + + for _, param := range decl.Type.Params.List { + starExpr, ok := param.Type.(*ast.StarExpr) + if !ok { + continue + } + + selectorExpr, ok := starExpr.X.(*ast.SelectorExpr) + if !ok { + continue + } + + xIdent, ok := selectorExpr.X.(*ast.Ident) + if !ok || xIdent.Name != "testing" { + continue + } + + if selectorExpr.Sel.Name != "T" { + continue + } + + param.Type = newGinkgoTInterface() + } + } +} + +/* + * Walks all of the nodes in the file, replacing *testing.T in struct + * and func literal nodes. eg: + * type foo struct { *testing.T } + * var bar = func(t *testing.T) { } + */ +func walkNodesInRootNodeReplacingTestingT(rootNode *ast.File) { + ast.Inspect(rootNode, func(node ast.Node) bool { + if node == nil { + return false + } + + switch node := node.(type) { + case *ast.StructType: + replaceTestingTsInStructType(node) + case *ast.FuncLit: + replaceTypeDeclTestingTsInFuncLiteral(node) + } + + return true + }) +} + +/* + * replaces named *testing.T inside a composite literal + */ +func replaceNamedTestingTsInKeyValueExpression(kve *ast.KeyValueExpr, testingT string) { + ident, ok := kve.Value.(*ast.Ident) + if !ok { + return + } + + if ident.Name == testingT { + kve.Value = newGinkgoTFromIdent(ident) + } +} + +/* + * replaces *testing.T params in a func literal with GinkgoT + */ +func replaceTypeDeclTestingTsInFuncLiteral(functionLiteral *ast.FuncLit) { + for _, arg := range functionLiteral.Type.Params.List { + starExpr, ok := arg.Type.(*ast.StarExpr) + if !ok { + continue + } + + selectorExpr, ok := starExpr.X.(*ast.SelectorExpr) + if !ok { + continue + } + + target, ok := selectorExpr.X.(*ast.Ident) + if !ok { + continue + } + + if target.Name == "testing" && selectorExpr.Sel.Name == "T" { + arg.Type = newGinkgoTInterface() + } + } +} + +/* + * Replaces *testing.T types inside of a struct declaration with a GinkgoT + * eg: type foo struct { *testing.T } + */ +func replaceTestingTsInStructType(structType *ast.StructType) { + for _, field := range structType.Fields.List { + starExpr, ok := field.Type.(*ast.StarExpr) + if !ok { + continue + } + + selectorExpr, ok := starExpr.X.(*ast.SelectorExpr) + if !ok { + continue + } + + xIdent, ok := selectorExpr.X.(*ast.Ident) + if !ok { + continue + } + + if xIdent.Name == "testing" && selectorExpr.Sel.Name == "T" { + field.Type = newGinkgoTInterface() + } + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go new file mode 100644 index 00000000..8e99f56a --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go @@ -0,0 +1,51 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/onsi/ginkgo/ginkgo/convert" + colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" + "github.com/onsi/ginkgo/types" +) + +func BuildConvertCommand() *Command { + return &Command{ + Name: "convert", + FlagSet: flag.NewFlagSet("convert", flag.ExitOnError), + UsageCommand: "ginkgo convert /path/to/package", + Usage: []string{ + "Convert the package at the passed in path from an XUnit-style test to a Ginkgo-style test", + }, + Command: convertPackage, + } +} + +func convertPackage(args []string, additionalArgs []string) { + deprecationTracker := types.NewDeprecationTracker() + deprecationTracker.TrackDeprecation(types.Deprecations.Convert()) + fmt.Fprintln(colorable.NewColorableStderr(), deprecationTracker.DeprecationsReport()) + + if len(args) != 1 { + println(fmt.Sprintf("usage: ginkgo convert /path/to/your/package")) + os.Exit(1) + } + + defer func() { + err := recover() + if err != nil { + switch err := err.(type) { + case error: + println(err.Error()) + case string: + println(err) + default: + println(fmt.Sprintf("unexpected error: %#v", err)) + } + os.Exit(1) + } + }() + + convert.RewritePackage(args[0]) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/generate_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/generate_command.go new file mode 100644 index 00000000..27758beb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/generate_command.go @@ -0,0 +1,273 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + "text/template" + + sprig "github.com/go-task/slim-sprig" +) + +func BuildGenerateCommand() *Command { + var ( + agouti, noDot, internal bool + customTestFile string + ) + flagSet := flag.NewFlagSet("generate", flag.ExitOnError) + flagSet.BoolVar(&agouti, "agouti", false, "If set, generate will generate a test file for writing Agouti tests") + flagSet.BoolVar(&noDot, "nodot", false, "If set, generate will generate a test file that does not . import ginkgo and gomega") + flagSet.BoolVar(&internal, "internal", false, "If set, generate will generate a test file that uses the regular package name") + flagSet.StringVar(&customTestFile, "template", "", "If specified, generate will use the contents of the file passed as the test file template") + + return &Command{ + Name: "generate", + FlagSet: flagSet, + UsageCommand: "ginkgo generate ", + Usage: []string{ + "Generate a test file named filename_test.go", + "If the optional argument is omitted, a file named after the package in the current directory will be created.", + "Accepts the following flags:", + }, + Command: func(args []string, additionalArgs []string) { + generateSpec(args, agouti, noDot, internal, customTestFile) + }, + } +} + +var specText = `package {{.Package}} + +import ( + {{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}} + {{if .IncludeImports}}. "github.com/onsi/gomega"{{end}} + + {{if .ImportPackage}}"{{.PackageImportPath}}"{{end}} +) + +var _ = Describe("{{.Subject}}", func() { + +}) +` + +var agoutiSpecText = `package {{.Package}} + +import ( + {{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}} + {{if .IncludeImports}}. "github.com/onsi/gomega"{{end}} + "github.com/sclevine/agouti" + . "github.com/sclevine/agouti/matchers" + + {{if .ImportPackage}}"{{.PackageImportPath}}"{{end}} +) + +var _ = Describe("{{.Subject}}", func() { + var page *agouti.Page + + BeforeEach(func() { + var err error + page, err = agoutiDriver.NewPage() + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(page.Destroy()).To(Succeed()) + }) +}) +` + +type specData struct { + Package string + Subject string + PackageImportPath string + IncludeImports bool + ImportPackage bool +} + +func generateSpec(args []string, agouti, noDot, internal bool, customTestFile string) { + if len(args) == 0 { + err := generateSpecForSubject("", agouti, noDot, internal, customTestFile) + if err != nil { + fmt.Println(err.Error()) + fmt.Println("") + os.Exit(1) + } + fmt.Println("") + return + } + + var failed bool + for _, arg := range args { + err := generateSpecForSubject(arg, agouti, noDot, internal, customTestFile) + if err != nil { + failed = true + fmt.Println(err.Error()) + } + } + fmt.Println("") + if failed { + os.Exit(1) + } +} + +func generateSpecForSubject(subject string, agouti, noDot, internal bool, customTestFile string) error { + packageName, specFilePrefix, formattedName := getPackageAndFormattedName() + if subject != "" { + specFilePrefix = formatSubject(subject) + formattedName = prettifyPackageName(specFilePrefix) + } + + if internal { + specFilePrefix = specFilePrefix + "_internal" + } + + data := specData{ + Package: determinePackageName(packageName, internal), + Subject: formattedName, + PackageImportPath: getPackageImportPath(), + IncludeImports: !noDot, + ImportPackage: !internal, + } + + targetFile := fmt.Sprintf("%s_test.go", specFilePrefix) + if fileExists(targetFile) { + return fmt.Errorf("%s already exists.", targetFile) + } else { + fmt.Printf("Generating ginkgo test for %s in:\n %s\n", data.Subject, targetFile) + } + + f, err := os.Create(targetFile) + if err != nil { + return err + } + defer f.Close() + + var templateText string + if customTestFile != "" { + tpl, err := ioutil.ReadFile(customTestFile) + if err != nil { + panic(err.Error()) + } + templateText = string(tpl) + } else if agouti { + templateText = agoutiSpecText + } else { + templateText = specText + } + + specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Parse(templateText) + if err != nil { + return err + } + + specTemplate.Execute(f, data) + goFmt(targetFile) + return nil +} + +func formatSubject(name string) string { + name = strings.Replace(name, "-", "_", -1) + name = strings.Replace(name, " ", "_", -1) + name = strings.Split(name, ".go")[0] + name = strings.Split(name, "_test")[0] + return name +} + +// moduleName returns module name from go.mod from given module root directory +func moduleName(modRoot string) string { + modFile, err := os.Open(filepath.Join(modRoot, "go.mod")) + if err != nil { + return "" + } + + mod := make([]byte, 128) + _, err = modFile.Read(mod) + if err != nil { + return "" + } + + slashSlash := []byte("//") + moduleStr := []byte("module") + + for len(mod) > 0 { + line := mod + mod = nil + if i := bytes.IndexByte(line, '\n'); i >= 0 { + line, mod = line[:i], line[i+1:] + } + if i := bytes.Index(line, slashSlash); i >= 0 { + line = line[:i] + } + line = bytes.TrimSpace(line) + if !bytes.HasPrefix(line, moduleStr) { + continue + } + line = line[len(moduleStr):] + n := len(line) + line = bytes.TrimSpace(line) + if len(line) == n || len(line) == 0 { + continue + } + + if line[0] == '"' || line[0] == '`' { + p, err := strconv.Unquote(string(line)) + if err != nil { + return "" // malformed quoted string or multiline module path + } + return p + } + + return string(line) + } + + return "" // missing module path +} + +func findModuleRoot(dir string) (root string) { + dir = filepath.Clean(dir) + + // Look for enclosing go.mod. + for { + if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() { + return dir + } + d := filepath.Dir(dir) + if d == dir { + break + } + dir = d + } + return "" +} + +func getPackageImportPath() string { + workingDir, err := os.Getwd() + if err != nil { + panic(err.Error()) + } + + sep := string(filepath.Separator) + + // Try go.mod file first + modRoot := findModuleRoot(workingDir) + if modRoot != "" { + modName := moduleName(modRoot) + if modName != "" { + cd := strings.Replace(workingDir, modRoot, "", -1) + cd = strings.ReplaceAll(cd, sep, "/") + return modName + cd + } + } + + // Fallback to GOPATH structure + paths := strings.Split(workingDir, sep+"src"+sep) + if len(paths) == 1 { + fmt.Printf("\nCouldn't identify package import path.\n\n\tginkgo generate\n\nMust be run within a package directory under $GOPATH/src/...\nYou're going to have to change UNKNOWN_PACKAGE_PATH in the generated file...\n\n") + return "UNKNOWN_PACKAGE_PATH" + } + return filepath.ToSlash(paths[len(paths)-1]) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/help_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/help_command.go new file mode 100644 index 00000000..23b1d2f1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/help_command.go @@ -0,0 +1,31 @@ +package main + +import ( + "flag" + "fmt" +) + +func BuildHelpCommand() *Command { + return &Command{ + Name: "help", + FlagSet: flag.NewFlagSet("help", flag.ExitOnError), + UsageCommand: "ginkgo help ", + Usage: []string{ + "Print usage information. If a command is passed in, print usage information just for that command.", + }, + Command: printHelp, + } +} + +func printHelp(args []string, additionalArgs []string) { + if len(args) == 0 { + usage() + } else { + command, found := commandMatching(args[0]) + if !found { + complainAndQuit(fmt.Sprintf("Unknown command: %s", args[0])) + } + + usageForCommand(command, true) + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go b/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go new file mode 100644 index 00000000..ec456bf2 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go @@ -0,0 +1,52 @@ +package interrupthandler + +import ( + "os" + "os/signal" + "sync" + "syscall" +) + +type InterruptHandler struct { + interruptCount int + lock *sync.Mutex + C chan bool +} + +func NewInterruptHandler() *InterruptHandler { + h := &InterruptHandler{ + lock: &sync.Mutex{}, + C: make(chan bool), + } + + go h.handleInterrupt() + SwallowSigQuit() + + return h +} + +func (h *InterruptHandler) WasInterrupted() bool { + h.lock.Lock() + defer h.lock.Unlock() + + return h.interruptCount > 0 +} + +func (h *InterruptHandler) handleInterrupt() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + <-c + signal.Stop(c) + + h.lock.Lock() + h.interruptCount++ + if h.interruptCount == 1 { + close(h.C) + } else if h.interruptCount > 5 { + os.Exit(1) + } + h.lock.Unlock() + + go h.handleInterrupt() +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_unix.go b/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_unix.go new file mode 100644 index 00000000..43c18544 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_unix.go @@ -0,0 +1,14 @@ +// +build freebsd openbsd netbsd dragonfly darwin linux solaris + +package interrupthandler + +import ( + "os" + "os/signal" + "syscall" +) + +func SwallowSigQuit() { + c := make(chan os.Signal, 1024) + signal.Notify(c, syscall.SIGQUIT) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_windows.go b/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_windows.go new file mode 100644 index 00000000..7f4a50e1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_windows.go @@ -0,0 +1,7 @@ +// +build windows + +package interrupthandler + +func SwallowSigQuit() { + //noop +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/main.go b/vendor/github.com/onsi/ginkgo/ginkgo/main.go new file mode 100644 index 00000000..ac725bf4 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/main.go @@ -0,0 +1,308 @@ +/* +The Ginkgo CLI + +The Ginkgo CLI is fully documented [here](http://onsi.github.io/ginkgo/#the_ginkgo_cli) + +You can also learn more by running: + + ginkgo help + +Here are some of the more commonly used commands: + +To install: + + go install github.com/onsi/ginkgo/ginkgo + +To run tests: + + ginkgo + +To run tests in all subdirectories: + + ginkgo -r + +To run tests in particular packages: + + ginkgo /path/to/package /path/to/another/package + +To pass arguments/flags to your tests: + + ginkgo -- + +To run tests in parallel + + ginkgo -p + +this will automatically detect the optimal number of nodes to use. Alternatively, you can specify the number of nodes with: + + ginkgo -nodes=N + +(note that you don't need to provide -p in this case). + +By default the Ginkgo CLI will spin up a server that the individual test processes send test output to. The CLI aggregates this output and then presents coherent test output, one test at a time, as each test completes. +An alternative is to have the parallel nodes run and stream interleaved output back. This useful for debugging, particularly in contexts where tests hang/fail to start. To get this interleaved output: + + ginkgo -nodes=N -stream=true + +On windows, the default value for stream is true. + +By default, when running multiple tests (with -r or a list of packages) Ginkgo will abort when a test fails. To have Ginkgo run subsequent test suites instead you can: + + ginkgo -keepGoing + +To fail if there are ginkgo tests in a directory but no test suite (missing `RunSpecs`) + + ginkgo -requireSuite + +To monitor packages and rerun tests when changes occur: + + ginkgo watch <-r> + +passing `ginkgo watch` the `-r` flag will recursively detect all test suites under the current directory and monitor them. +`watch` does not detect *new* packages. Moreover, changes in package X only rerun the tests for package X, tests for packages +that depend on X are not rerun. + +[OSX & Linux only] To receive (desktop) notifications when a test run completes: + + ginkgo -notify + +this is particularly useful with `ginkgo watch`. Notifications are currently only supported on OS X and require that you `brew install terminal-notifier` + +Sometimes (to suss out race conditions/flakey tests, for example) you want to keep running a test suite until it fails. You can do this with: + + ginkgo -untilItFails + +To bootstrap a test suite: + + ginkgo bootstrap + +To generate a test file: + + ginkgo generate + +To bootstrap/generate test files without using "." imports: + + ginkgo bootstrap --nodot + ginkgo generate --nodot + +this will explicitly export all the identifiers in Ginkgo and Gomega allowing you to rename them to avoid collisions. When you pull to the latest Ginkgo/Gomega you'll want to run + + ginkgo nodot + +to refresh this list and pull in any new identifiers. In particular, this will pull in any new Gomega matchers that get added. + +To convert an existing XUnit style test suite to a Ginkgo-style test suite: + + ginkgo convert . + +To unfocus tests: + + ginkgo unfocus + +or + + ginkgo blur + +To compile a test suite: + + ginkgo build + +will output an executable file named `package.test`. This can be run directly or by invoking + + ginkgo + + +To print an outline of Ginkgo specs and containers in a file: + + gingko outline + +To print out Ginkgo's version: + + ginkgo version + +To get more help: + + ginkgo help +*/ +package main + +import ( + "flag" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/ginkgo/testsuite" +) + +const greenColor = "\x1b[32m" +const redColor = "\x1b[91m" +const defaultStyle = "\x1b[0m" +const lightGrayColor = "\x1b[37m" + +type Command struct { + Name string + AltName string + FlagSet *flag.FlagSet + Usage []string + UsageCommand string + Command func(args []string, additionalArgs []string) + SuppressFlagDocumentation bool + FlagDocSubstitute []string +} + +func (c *Command) Matches(name string) bool { + return c.Name == name || (c.AltName != "" && c.AltName == name) +} + +func (c *Command) Run(args []string, additionalArgs []string) { + c.FlagSet.Usage = usage + c.FlagSet.Parse(args) + c.Command(c.FlagSet.Args(), additionalArgs) +} + +var DefaultCommand *Command +var Commands []*Command + +func init() { + DefaultCommand = BuildRunCommand() + Commands = append(Commands, BuildWatchCommand()) + Commands = append(Commands, BuildBuildCommand()) + Commands = append(Commands, BuildBootstrapCommand()) + Commands = append(Commands, BuildGenerateCommand()) + Commands = append(Commands, BuildNodotCommand()) + Commands = append(Commands, BuildConvertCommand()) + Commands = append(Commands, BuildUnfocusCommand()) + Commands = append(Commands, BuildVersionCommand()) + Commands = append(Commands, BuildHelpCommand()) + Commands = append(Commands, BuildOutlineCommand()) +} + +func main() { + args := []string{} + additionalArgs := []string{} + + foundDelimiter := false + + for _, arg := range os.Args[1:] { + if !foundDelimiter { + if arg == "--" { + foundDelimiter = true + continue + } + } + + if foundDelimiter { + additionalArgs = append(additionalArgs, arg) + } else { + args = append(args, arg) + } + } + + if len(args) > 0 { + commandToRun, found := commandMatching(args[0]) + if found { + commandToRun.Run(args[1:], additionalArgs) + return + } + } + + DefaultCommand.Run(args, additionalArgs) +} + +func commandMatching(name string) (*Command, bool) { + for _, command := range Commands { + if command.Matches(name) { + return command, true + } + } + return nil, false +} + +func usage() { + fmt.Printf("Ginkgo Version %s\n\n", config.VERSION) + usageForCommand(DefaultCommand, false) + for _, command := range Commands { + fmt.Printf("\n") + usageForCommand(command, false) + } +} + +func usageForCommand(command *Command, longForm bool) { + fmt.Printf("%s\n%s\n", command.UsageCommand, strings.Repeat("-", len(command.UsageCommand))) + fmt.Printf("%s\n", strings.Join(command.Usage, "\n")) + if command.SuppressFlagDocumentation && !longForm { + fmt.Printf("%s\n", strings.Join(command.FlagDocSubstitute, "\n ")) + } else { + command.FlagSet.SetOutput(os.Stdout) + command.FlagSet.PrintDefaults() + } +} + +func complainAndQuit(complaint string) { + fmt.Fprintf(os.Stderr, "%s\nFor usage instructions:\n\tginkgo help\n", complaint) + os.Exit(1) +} + +func findSuites(args []string, recurseForAll bool, skipPackage string, allowPrecompiled bool) ([]testsuite.TestSuite, []string) { + suites := []testsuite.TestSuite{} + + if len(args) > 0 { + for _, arg := range args { + if allowPrecompiled { + suite, err := testsuite.PrecompiledTestSuite(arg) + if err == nil { + suites = append(suites, suite) + continue + } + } + recurseForSuite := recurseForAll + if strings.HasSuffix(arg, "/...") && arg != "/..." { + arg = arg[:len(arg)-4] + recurseForSuite = true + } + suites = append(suites, testsuite.SuitesInDir(arg, recurseForSuite)...) + } + } else { + suites = testsuite.SuitesInDir(".", recurseForAll) + } + + skippedPackages := []string{} + if skipPackage != "" { + skipFilters := strings.Split(skipPackage, ",") + filteredSuites := []testsuite.TestSuite{} + for _, suite := range suites { + skip := false + for _, skipFilter := range skipFilters { + if strings.Contains(suite.Path, skipFilter) { + skip = true + break + } + } + if skip { + skippedPackages = append(skippedPackages, suite.Path) + } else { + filteredSuites = append(filteredSuites, suite) + } + } + suites = filteredSuites + } + + return suites, skippedPackages +} + +func goFmt(path string) { + out, err := exec.Command("go", "fmt", path).CombinedOutput() + if err != nil { + complainAndQuit("Could not fmt: " + err.Error() + "\n" + string(out)) + } +} + +func pluralizedWord(singular, plural string, count int) string { + if count == 1 { + return singular + } + return plural +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go b/vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go new file mode 100644 index 00000000..c87b7216 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go @@ -0,0 +1,196 @@ +package nodot + +import ( + "fmt" + "go/ast" + "go/build" + "go/parser" + "go/token" + "path/filepath" + "strings" +) + +func ApplyNoDot(data []byte) ([]byte, error) { + sections, err := generateNodotSections() + if err != nil { + return nil, err + } + + for _, section := range sections { + data = section.createOrUpdateIn(data) + } + + return data, nil +} + +type nodotSection struct { + name string + pkg string + declarations []string + types []string +} + +func (s nodotSection) createOrUpdateIn(data []byte) []byte { + renames := map[string]string{} + + contents := string(data) + + lines := strings.Split(contents, "\n") + + comment := "// Declarations for " + s.name + + newLines := []string{} + for _, line := range lines { + if line == comment { + continue + } + + words := strings.Split(line, " ") + lastWord := words[len(words)-1] + + if s.containsDeclarationOrType(lastWord) { + renames[lastWord] = words[1] + continue + } + + newLines = append(newLines, line) + } + + if len(newLines[len(newLines)-1]) > 0 { + newLines = append(newLines, "") + } + + newLines = append(newLines, comment) + + for _, typ := range s.types { + name, ok := renames[s.prefix(typ)] + if !ok { + name = typ + } + newLines = append(newLines, fmt.Sprintf("type %s %s", name, s.prefix(typ))) + } + + for _, decl := range s.declarations { + name, ok := renames[s.prefix(decl)] + if !ok { + name = decl + } + newLines = append(newLines, fmt.Sprintf("var %s = %s", name, s.prefix(decl))) + } + + newLines = append(newLines, "") + + newContents := strings.Join(newLines, "\n") + + return []byte(newContents) +} + +func (s nodotSection) prefix(declOrType string) string { + return s.pkg + "." + declOrType +} + +func (s nodotSection) containsDeclarationOrType(word string) bool { + for _, declaration := range s.declarations { + if s.prefix(declaration) == word { + return true + } + } + + for _, typ := range s.types { + if s.prefix(typ) == word { + return true + } + } + + return false +} + +func generateNodotSections() ([]nodotSection, error) { + sections := []nodotSection{} + + declarations, err := getExportedDeclerationsForPackage("github.com/onsi/ginkgo", "ginkgo_dsl.go", "GINKGO_VERSION", "GINKGO_PANIC") + if err != nil { + return nil, err + } + sections = append(sections, nodotSection{ + name: "Ginkgo DSL", + pkg: "ginkgo", + declarations: declarations, + types: []string{"Done", "Benchmarker"}, + }) + + declarations, err = getExportedDeclerationsForPackage("github.com/onsi/gomega", "gomega_dsl.go", "GOMEGA_VERSION") + if err != nil { + return nil, err + } + sections = append(sections, nodotSection{ + name: "Gomega DSL", + pkg: "gomega", + declarations: declarations, + }) + + declarations, err = getExportedDeclerationsForPackage("github.com/onsi/gomega", "matchers.go") + if err != nil { + return nil, err + } + sections = append(sections, nodotSection{ + name: "Gomega Matchers", + pkg: "gomega", + declarations: declarations, + }) + + return sections, nil +} + +func getExportedDeclerationsForPackage(pkgPath string, filename string, blacklist ...string) ([]string, error) { + pkg, err := build.Import(pkgPath, ".", 0) + if err != nil { + return []string{}, err + } + + declarations, err := getExportedDeclarationsForFile(filepath.Join(pkg.Dir, filename)) + if err != nil { + return []string{}, err + } + + blacklistLookup := map[string]bool{} + for _, declaration := range blacklist { + blacklistLookup[declaration] = true + } + + filteredDeclarations := []string{} + for _, declaration := range declarations { + if blacklistLookup[declaration] { + continue + } + filteredDeclarations = append(filteredDeclarations, declaration) + } + + return filteredDeclarations, nil +} + +func getExportedDeclarationsForFile(path string) ([]string, error) { + fset := token.NewFileSet() + tree, err := parser.ParseFile(fset, path, nil, 0) + if err != nil { + return []string{}, err + } + + declarations := []string{} + ast.FileExports(tree) + for _, decl := range tree.Decls { + switch x := decl.(type) { + case *ast.GenDecl: + switch s := x.Specs[0].(type) { + case *ast.ValueSpec: + declarations = append(declarations, s.Names[0].Name) + } + case *ast.FuncDecl: + if x.Recv == nil { + declarations = append(declarations, x.Name.Name) + } + } + } + + return declarations, nil +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go new file mode 100644 index 00000000..39b88b5d --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go @@ -0,0 +1,77 @@ +package main + +import ( + "bufio" + "flag" + "io/ioutil" + "os" + "path/filepath" + "regexp" + + "github.com/onsi/ginkgo/ginkgo/nodot" +) + +func BuildNodotCommand() *Command { + return &Command{ + Name: "nodot", + FlagSet: flag.NewFlagSet("bootstrap", flag.ExitOnError), + UsageCommand: "ginkgo nodot", + Usage: []string{ + "Update the nodot declarations in your test suite", + "Any missing declarations (from, say, a recently added matcher) will be added to your bootstrap file.", + "If you've renamed a declaration, that name will be honored and not overwritten.", + }, + Command: updateNodot, + } +} + +func updateNodot(args []string, additionalArgs []string) { + suiteFile, perm := findSuiteFile() + + data, err := ioutil.ReadFile(suiteFile) + if err != nil { + complainAndQuit("Failed to update nodot declarations: " + err.Error()) + } + + content, err := nodot.ApplyNoDot(data) + if err != nil { + complainAndQuit("Failed to update nodot declarations: " + err.Error()) + } + ioutil.WriteFile(suiteFile, content, perm) + + goFmt(suiteFile) +} + +func findSuiteFile() (string, os.FileMode) { + workingDir, err := os.Getwd() + if err != nil { + complainAndQuit("Could not find suite file for nodot: " + err.Error()) + } + + files, err := ioutil.ReadDir(workingDir) + if err != nil { + complainAndQuit("Could not find suite file for nodot: " + err.Error()) + } + + re := regexp.MustCompile(`RunSpecs\(|RunSpecsWithDefaultAndCustomReporters\(|RunSpecsWithCustomReporters\(`) + + for _, file := range files { + if file.IsDir() { + continue + } + path := filepath.Join(workingDir, file.Name()) + f, err := os.Open(path) + if err != nil { + complainAndQuit("Could not find suite file for nodot: " + err.Error()) + } + defer f.Close() + + if re.MatchReader(bufio.NewReader(f)) { + return path, file.Mode() + } + } + + complainAndQuit("Could not find a suite file for nodot: you need a bootstrap file that call's Ginkgo's RunSpecs() command.\nTry running ginkgo bootstrap first.") + + return "", 0 +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/notifications.go b/vendor/github.com/onsi/ginkgo/ginkgo/notifications.go new file mode 100644 index 00000000..368d61fb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/notifications.go @@ -0,0 +1,141 @@ +package main + +import ( + "fmt" + "os" + "os/exec" + "regexp" + "runtime" + "strings" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/ginkgo/testsuite" +) + +type Notifier struct { + commandFlags *RunWatchAndBuildCommandFlags +} + +func NewNotifier(commandFlags *RunWatchAndBuildCommandFlags) *Notifier { + return &Notifier{ + commandFlags: commandFlags, + } +} + +func (n *Notifier) VerifyNotificationsAreAvailable() { + if n.commandFlags.Notify { + onLinux := (runtime.GOOS == "linux") + onOSX := (runtime.GOOS == "darwin") + if onOSX { + + _, err := exec.LookPath("terminal-notifier") + if err != nil { + fmt.Printf(`--notify requires terminal-notifier, which you don't seem to have installed. + +OSX: + +To remedy this: + + brew install terminal-notifier + +To learn more about terminal-notifier: + + https://github.com/alloy/terminal-notifier +`) + os.Exit(1) + } + + } else if onLinux { + + _, err := exec.LookPath("notify-send") + if err != nil { + fmt.Printf(`--notify requires terminal-notifier or notify-send, which you don't seem to have installed. + +Linux: + +Download and install notify-send for your distribution +`) + os.Exit(1) + } + + } + } +} + +func (n *Notifier) SendSuiteCompletionNotification(suite testsuite.TestSuite, suitePassed bool) { + if suitePassed { + n.SendNotification("Ginkgo [PASS]", fmt.Sprintf(`Test suite for "%s" passed.`, suite.PackageName)) + } else { + n.SendNotification("Ginkgo [FAIL]", fmt.Sprintf(`Test suite for "%s" failed.`, suite.PackageName)) + } +} + +func (n *Notifier) SendNotification(title string, subtitle string) { + + if n.commandFlags.Notify { + onLinux := (runtime.GOOS == "linux") + onOSX := (runtime.GOOS == "darwin") + + if onOSX { + + _, err := exec.LookPath("terminal-notifier") + if err == nil { + args := []string{"-title", title, "-subtitle", subtitle, "-group", "com.onsi.ginkgo"} + terminal := os.Getenv("TERM_PROGRAM") + if terminal == "iTerm.app" { + args = append(args, "-activate", "com.googlecode.iterm2") + } else if terminal == "Apple_Terminal" { + args = append(args, "-activate", "com.apple.Terminal") + } + + exec.Command("terminal-notifier", args...).Run() + } + + } else if onLinux { + + _, err := exec.LookPath("notify-send") + if err == nil { + args := []string{"-a", "ginkgo", title, subtitle} + exec.Command("notify-send", args...).Run() + } + + } + } +} + +func (n *Notifier) RunCommand(suite testsuite.TestSuite, suitePassed bool) { + + command := n.commandFlags.AfterSuiteHook + if command != "" { + + // Allow for string replacement to pass input to the command + passed := "[FAIL]" + if suitePassed { + passed = "[PASS]" + } + command = strings.Replace(command, "(ginkgo-suite-passed)", passed, -1) + command = strings.Replace(command, "(ginkgo-suite-name)", suite.PackageName, -1) + + // Must break command into parts + splitArgs := regexp.MustCompile(`'.+'|".+"|\S+`) + parts := splitArgs.FindAllString(command, -1) + + output, err := exec.Command(parts[0], parts[1:]...).CombinedOutput() + if err != nil { + fmt.Println("Post-suite command failed:") + if config.DefaultReporterConfig.NoColor { + fmt.Printf("\t%s\n", output) + } else { + fmt.Printf("\t%s%s%s\n", redColor, string(output), defaultStyle) + } + n.SendNotification("Ginkgo [ERROR]", fmt.Sprintf(`After suite command "%s" failed`, n.commandFlags.AfterSuiteHook)) + } else { + fmt.Println("Post-suite command succeeded:") + if config.DefaultReporterConfig.NoColor { + fmt.Printf("\t%s\n", output) + } else { + fmt.Printf("\t%s%s%s\n", greenColor, string(output), defaultStyle) + } + } + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline/ginkgo.go b/vendor/github.com/onsi/ginkgo/ginkgo/outline/ginkgo.go new file mode 100644 index 00000000..ce6b7fcd --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/outline/ginkgo.go @@ -0,0 +1,243 @@ +package outline + +import ( + "go/ast" + "go/token" + "strconv" +) + +const ( + // undefinedTextAlt is used if the spec/container text cannot be derived + undefinedTextAlt = "undefined" +) + +// ginkgoMetadata holds useful bits of information for every entry in the outline +type ginkgoMetadata struct { + // Name is the spec or container function name, e.g. `Describe` or `It` + Name string `json:"name"` + + // Text is the `text` argument passed to specs, and some containers + Text string `json:"text"` + + // Start is the position of first character of the spec or container block + Start int `json:"start"` + + // End is the position of first character immediately after the spec or container block + End int `json:"end"` + + Spec bool `json:"spec"` + Focused bool `json:"focused"` + Pending bool `json:"pending"` +} + +// ginkgoNode is used to construct the outline as a tree +type ginkgoNode struct { + ginkgoMetadata + Nodes []*ginkgoNode `json:"nodes"` +} + +type walkFunc func(n *ginkgoNode) + +func (n *ginkgoNode) PreOrder(f walkFunc) { + f(n) + for _, m := range n.Nodes { + m.PreOrder(f) + } +} + +func (n *ginkgoNode) PostOrder(f walkFunc) { + for _, m := range n.Nodes { + m.PostOrder(f) + } + f(n) +} + +func (n *ginkgoNode) Walk(pre, post walkFunc) { + pre(n) + for _, m := range n.Nodes { + m.Walk(pre, post) + } + post(n) +} + +// PropagateInheritedProperties propagates the Pending and Focused properties +// through the subtree rooted at n. +func (n *ginkgoNode) PropagateInheritedProperties() { + n.PreOrder(func(thisNode *ginkgoNode) { + for _, descendantNode := range thisNode.Nodes { + if thisNode.Pending { + descendantNode.Pending = true + descendantNode.Focused = false + } + if thisNode.Focused && !descendantNode.Pending { + descendantNode.Focused = true + } + } + }) +} + +// BackpropagateUnfocus propagates the Focused property through the subtree +// rooted at n. It applies the rule described in the Ginkgo docs: +// > Nested programmatically focused specs follow a simple rule: if a +// > leaf-node is marked focused, any of its ancestor nodes that are marked +// > focus will be unfocused. +func (n *ginkgoNode) BackpropagateUnfocus() { + focusedSpecInSubtreeStack := []bool{} + n.PostOrder(func(thisNode *ginkgoNode) { + if thisNode.Spec { + focusedSpecInSubtreeStack = append(focusedSpecInSubtreeStack, thisNode.Focused) + return + } + focusedSpecInSubtree := false + for range thisNode.Nodes { + focusedSpecInSubtree = focusedSpecInSubtree || focusedSpecInSubtreeStack[len(focusedSpecInSubtreeStack)-1] + focusedSpecInSubtreeStack = focusedSpecInSubtreeStack[0 : len(focusedSpecInSubtreeStack)-1] + } + focusedSpecInSubtreeStack = append(focusedSpecInSubtreeStack, focusedSpecInSubtree) + if focusedSpecInSubtree { + thisNode.Focused = false + } + }) + +} + +func packageAndIdentNamesFromCallExpr(ce *ast.CallExpr) (string, string, bool) { + switch ex := ce.Fun.(type) { + case *ast.Ident: + return "", ex.Name, true + case *ast.SelectorExpr: + pkgID, ok := ex.X.(*ast.Ident) + if !ok { + return "", "", false + } + // A package identifier is top-level, so Obj must be nil + if pkgID.Obj != nil { + return "", "", false + } + if ex.Sel == nil { + return "", "", false + } + return pkgID.Name, ex.Sel.Name, true + default: + return "", "", false + } +} + +// absoluteOffsetsForNode derives the absolute character offsets of the node start and +// end positions. +func absoluteOffsetsForNode(fset *token.FileSet, n ast.Node) (start, end int) { + return fset.PositionFor(n.Pos(), false).Offset, fset.PositionFor(n.End(), false).Offset +} + +// ginkgoNodeFromCallExpr derives an outline entry from a go AST subtree +// corresponding to a Ginkgo container or spec. +func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackageName, tablePackageName *string) (*ginkgoNode, bool) { + packageName, identName, ok := packageAndIdentNamesFromCallExpr(ce) + if !ok { + return nil, false + } + + n := ginkgoNode{} + n.Name = identName + n.Start, n.End = absoluteOffsetsForNode(fset, ce) + n.Nodes = make([]*ginkgoNode, 0) + switch identName { + case "It", "Measure", "Specify": + n.Spec = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "Entry": + n.Spec = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, tablePackageName != nil && *tablePackageName == packageName + case "FIt", "FMeasure", "FSpecify": + n.Spec = true + n.Focused = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "FEntry": + n.Spec = true + n.Focused = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, tablePackageName != nil && *tablePackageName == packageName + case "PIt", "PMeasure", "PSpecify", "XIt", "XMeasure", "XSpecify": + n.Spec = true + n.Pending = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "PEntry", "XEntry": + n.Spec = true + n.Pending = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, tablePackageName != nil && *tablePackageName == packageName + case "Context", "Describe", "When": + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "DescribeTable": + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, tablePackageName != nil && *tablePackageName == packageName + case "FContext", "FDescribe", "FWhen": + n.Focused = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "FDescribeTable": + n.Focused = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, tablePackageName != nil && *tablePackageName == packageName + case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen": + n.Pending = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "PDescribeTable", "XDescribeTable": + n.Pending = true + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, tablePackageName != nil && *tablePackageName == packageName + case "By": + n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "AfterEach", "BeforeEach": + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "JustAfterEach", "JustBeforeEach": + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "AfterSuite", "BeforeSuite": + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + case "SynchronizedAfterSuite", "SynchronizedBeforeSuite": + return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName + default: + return nil, false + } +} + +// textOrAltFromCallExpr tries to derive the "text" of a Ginkgo spec or +// container. If it cannot derive it, it returns the alt text. +func textOrAltFromCallExpr(ce *ast.CallExpr, alt string) string { + text, defined := textFromCallExpr(ce) + if !defined { + return alt + } + return text +} + +// textFromCallExpr tries to derive the "text" of a Ginkgo spec or container. If +// it cannot derive it, it returns false. +func textFromCallExpr(ce *ast.CallExpr) (string, bool) { + if len(ce.Args) < 1 { + return "", false + } + text, ok := ce.Args[0].(*ast.BasicLit) + if !ok { + return "", false + } + switch text.Kind { + case token.CHAR, token.STRING: + // For token.CHAR and token.STRING, Value is quoted + unquoted, err := strconv.Unquote(text.Value) + if err != nil { + // If unquoting fails, just use the raw Value + return text.Value, true + } + return unquoted, true + default: + return text.Value, true + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline/import.go b/vendor/github.com/onsi/ginkgo/ginkgo/outline/import.go new file mode 100644 index 00000000..4328ab39 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/outline/import.go @@ -0,0 +1,65 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Most of the required functions were available in the +// "golang.org/x/tools/go/ast/astutil" package, but not exported. +// They were copied from https://github.com/golang/tools/blob/2b0845dc783e36ae26d683f4915a5840ef01ab0f/go/ast/astutil/imports.go + +package outline + +import ( + "go/ast" + "strconv" + "strings" +) + +// packageNameForImport returns the package name for the package. If the package +// is not imported, it returns nil. "Package name" refers to `pkgname` in the +// call expression `pkgname.ExportedIdentifier`. Examples: +// (import path not found) -> nil +// "import example.com/pkg/foo" -> "foo" +// "import fooalias example.com/pkg/foo" -> "fooalias" +// "import . example.com/pkg/foo" -> "" +func packageNameForImport(f *ast.File, path string) *string { + spec := importSpec(f, path) + if spec == nil { + return nil + } + name := spec.Name.String() + if name == "" { + // If the package name is not explicitly specified, + // make an educated guess. This is not guaranteed to be correct. + lastSlash := strings.LastIndex(path, "/") + if lastSlash == -1 { + name = path + } else { + name = path[lastSlash+1:] + } + } + if name == "." { + name = "" + } + return &name +} + +// importSpec returns the import spec if f imports path, +// or nil otherwise. +func importSpec(f *ast.File, path string) *ast.ImportSpec { + for _, s := range f.Imports { + if importPath(s) == path { + return s + } + } + return nil +} + +// importPath returns the unquoted import path of s, +// or "" if the path is not properly quoted. +func importPath(s *ast.ImportSpec) string { + t, err := strconv.Unquote(s.Path.Value) + if err != nil { + return "" + } + return t +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline/outline.go b/vendor/github.com/onsi/ginkgo/ginkgo/outline/outline.go new file mode 100644 index 00000000..242e6a10 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/outline/outline.go @@ -0,0 +1,107 @@ +package outline + +import ( + "encoding/json" + "fmt" + "go/ast" + "go/token" + "strings" + + "golang.org/x/tools/go/ast/inspector" +) + +const ( + // ginkgoImportPath is the well-known ginkgo import path + ginkgoImportPath = "github.com/onsi/ginkgo" + + // tableImportPath is the well-known table extension import path + tableImportPath = "github.com/onsi/ginkgo/extensions/table" +) + +// FromASTFile returns an outline for a Ginkgo test source file +func FromASTFile(fset *token.FileSet, src *ast.File) (*outline, error) { + ginkgoPackageName := packageNameForImport(src, ginkgoImportPath) + tablePackageName := packageNameForImport(src, tableImportPath) + if ginkgoPackageName == nil && tablePackageName == nil { + return nil, fmt.Errorf("file does not import %q or %q", ginkgoImportPath, tableImportPath) + } + + root := ginkgoNode{} + stack := []*ginkgoNode{&root} + ispr := inspector.New([]*ast.File{src}) + ispr.Nodes([]ast.Node{(*ast.CallExpr)(nil)}, func(node ast.Node, push bool) bool { + if push { + // Pre-order traversal + ce, ok := node.(*ast.CallExpr) + if !ok { + // Because `Nodes` calls this function only when the node is an + // ast.CallExpr, this should never happen + panic(fmt.Errorf("node starting at %d, ending at %d is not an *ast.CallExpr", node.Pos(), node.End())) + } + gn, ok := ginkgoNodeFromCallExpr(fset, ce, ginkgoPackageName, tablePackageName) + if !ok { + // Node is not a Ginkgo spec or container, continue + return true + } + parent := stack[len(stack)-1] + parent.Nodes = append(parent.Nodes, gn) + stack = append(stack, gn) + return true + } + // Post-order traversal + start, end := absoluteOffsetsForNode(fset, node) + lastVisitedGinkgoNode := stack[len(stack)-1] + if start != lastVisitedGinkgoNode.Start || end != lastVisitedGinkgoNode.End { + // Node is not a Ginkgo spec or container, so it was not pushed onto the stack, continue + return true + } + stack = stack[0 : len(stack)-1] + return true + }) + if len(root.Nodes) == 0 { + return &outline{[]*ginkgoNode{}}, nil + } + + // Derive the final focused property for all nodes. This must be done + // _before_ propagating the inherited focused property. + root.BackpropagateUnfocus() + // Now, propagate inherited properties, including focused and pending. + root.PropagateInheritedProperties() + + return &outline{root.Nodes}, nil +} + +type outline struct { + Nodes []*ginkgoNode `json:"nodes"` +} + +func (o *outline) MarshalJSON() ([]byte, error) { + return json.Marshal(o.Nodes) +} + +// String returns a CSV-formatted outline. Spec or container are output in +// depth-first order. +func (o *outline) String() string { + return o.StringIndent(0) +} + +// StringIndent returns a CSV-formated outline, but every line is indented by +// one 'width' of spaces for every level of nesting. +func (o *outline) StringIndent(width int) string { + var b strings.Builder + b.WriteString("Name,Text,Start,End,Spec,Focused,Pending\n") + + currentIndent := 0 + pre := func(n *ginkgoNode) { + b.WriteString(fmt.Sprintf("%*s", currentIndent, "")) + b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending)) + currentIndent += width + } + post := func(n *ginkgoNode) { + currentIndent -= width + } + for _, n := range o.Nodes { + n.Walk(pre, post) + } + return b.String() +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go new file mode 100644 index 00000000..96ca7ad2 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go @@ -0,0 +1,95 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "go/parser" + "go/token" + "os" + + "github.com/onsi/ginkgo/ginkgo/outline" +) + +const ( + // indentWidth is the width used by the 'indent' output + indentWidth = 4 + // stdinAlias is a portable alias for stdin. This convention is used in + // other CLIs, e.g., kubectl. + stdinAlias = "-" + usageCommand = "ginkgo outline " +) + +func BuildOutlineCommand() *Command { + const defaultFormat = "csv" + var format string + flagSet := flag.NewFlagSet("outline", flag.ExitOnError) + flagSet.StringVar(&format, "format", defaultFormat, "Format of outline. Accepted: 'csv', 'indent', 'json'") + return &Command{ + Name: "outline", + FlagSet: flagSet, + UsageCommand: usageCommand, + Usage: []string{ + "Create an outline of Ginkgo symbols for a file", + "To read from stdin, use: `ginkgo outline -`", + "Accepts the following flags:", + }, + Command: func(args []string, additionalArgs []string) { + outlineFile(args, format) + }, + } +} + +func outlineFile(args []string, format string) { + if len(args) != 1 { + println(fmt.Sprintf("usage: %s", usageCommand)) + os.Exit(1) + } + + filename := args[0] + var src *os.File + if filename == stdinAlias { + src = os.Stdin + } else { + var err error + src, err = os.Open(filename) + if err != nil { + println(fmt.Sprintf("error opening file: %s", err)) + os.Exit(1) + } + } + + fset := token.NewFileSet() + + parsedSrc, err := parser.ParseFile(fset, filename, src, 0) + if err != nil { + println(fmt.Sprintf("error parsing source: %s", err)) + os.Exit(1) + } + + o, err := outline.FromASTFile(fset, parsedSrc) + if err != nil { + println(fmt.Sprintf("error creating outline: %s", err)) + os.Exit(1) + } + + var oerr error + switch format { + case "csv": + _, oerr = fmt.Print(o) + case "indent": + _, oerr = fmt.Print(o.StringIndent(indentWidth)) + case "json": + b, err := json.Marshal(o) + if err != nil { + println(fmt.Sprintf("error marshalling to json: %s", err)) + } + _, oerr = fmt.Println(string(b)) + default: + complainAndQuit(fmt.Sprintf("format %s not accepted", format)) + } + if oerr != nil { + println(fmt.Sprintf("error writing outline: %s", oerr)) + os.Exit(1) + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/run_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/run_command.go new file mode 100644 index 00000000..c7f80d14 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/run_command.go @@ -0,0 +1,315 @@ +package main + +import ( + "flag" + "fmt" + "math/rand" + "os" + "regexp" + "runtime" + "strings" + "time" + + "io/ioutil" + "path/filepath" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/ginkgo/interrupthandler" + "github.com/onsi/ginkgo/ginkgo/testrunner" + colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" + "github.com/onsi/ginkgo/types" +) + +func BuildRunCommand() *Command { + commandFlags := NewRunCommandFlags(flag.NewFlagSet("ginkgo", flag.ExitOnError)) + notifier := NewNotifier(commandFlags) + interruptHandler := interrupthandler.NewInterruptHandler() + runner := &SpecRunner{ + commandFlags: commandFlags, + notifier: notifier, + interruptHandler: interruptHandler, + suiteRunner: NewSuiteRunner(notifier, interruptHandler), + } + + return &Command{ + Name: "", + FlagSet: commandFlags.FlagSet, + UsageCommand: "ginkgo -- ", + Usage: []string{ + "Run the tests in the passed in (or the package in the current directory if left blank).", + "Any arguments after -- will be passed to the test.", + "Accepts the following flags:", + }, + Command: runner.RunSpecs, + } +} + +type SpecRunner struct { + commandFlags *RunWatchAndBuildCommandFlags + notifier *Notifier + interruptHandler *interrupthandler.InterruptHandler + suiteRunner *SuiteRunner +} + +func (r *SpecRunner) RunSpecs(args []string, additionalArgs []string) { + r.commandFlags.computeNodes() + r.notifier.VerifyNotificationsAreAvailable() + + deprecationTracker := types.NewDeprecationTracker() + + if r.commandFlags.ParallelStream && (runtime.GOOS != "windows") { + deprecationTracker.TrackDeprecation(types.Deprecation{ + Message: "--stream is deprecated and will be removed in Ginkgo 2.0", + DocLink: "removed--stream", + Version: "1.16.0", + }) + } + + if r.commandFlags.Notify { + deprecationTracker.TrackDeprecation(types.Deprecation{ + Message: "--notify is deprecated and will be removed in Ginkgo 2.0", + DocLink: "removed--notify", + Version: "1.16.0", + }) + } + + if deprecationTracker.DidTrackDeprecations() { + fmt.Fprintln(colorable.NewColorableStderr(), deprecationTracker.DeprecationsReport()) + } + + suites, skippedPackages := findSuites(args, r.commandFlags.Recurse, r.commandFlags.SkipPackage, true) + if len(skippedPackages) > 0 { + fmt.Println("Will skip:") + for _, skippedPackage := range skippedPackages { + fmt.Println(" " + skippedPackage) + } + } + + if len(skippedPackages) > 0 && len(suites) == 0 { + fmt.Println("All tests skipped! Exiting...") + os.Exit(0) + } + + if len(suites) == 0 { + complainAndQuit("Found no test suites") + } + + r.ComputeSuccinctMode(len(suites)) + + t := time.Now() + + runners := []*testrunner.TestRunner{} + for _, suite := range suites { + runners = append(runners, testrunner.New(suite, r.commandFlags.NumCPU, r.commandFlags.ParallelStream, r.commandFlags.Timeout, r.commandFlags.GoOpts, additionalArgs)) + } + + numSuites := 0 + runResult := testrunner.PassingRunResult() + if r.commandFlags.UntilItFails { + iteration := 0 + for { + r.UpdateSeed() + randomizedRunners := r.randomizeOrder(runners) + runResult, numSuites = r.suiteRunner.RunSuites(randomizedRunners, r.commandFlags.NumCompilers, r.commandFlags.KeepGoing, nil) + iteration++ + + if r.interruptHandler.WasInterrupted() { + break + } + + if runResult.Passed { + fmt.Printf("\nAll tests passed...\nWill keep running them until they fail.\nThis was attempt #%d\n%s\n", iteration, orcMessage(iteration)) + } else { + fmt.Printf("\nTests failed on attempt #%d\n\n", iteration) + break + } + } + } else { + randomizedRunners := r.randomizeOrder(runners) + runResult, numSuites = r.suiteRunner.RunSuites(randomizedRunners, r.commandFlags.NumCompilers, r.commandFlags.KeepGoing, nil) + } + + for _, runner := range runners { + runner.CleanUp() + } + + if r.isInCoverageMode() { + if r.getOutputDir() != "" { + // If coverprofile is set, combine coverages + if r.getCoverprofile() != "" { + if err := r.combineCoverprofiles(runners); err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } + } else { + // Just move them + r.moveCoverprofiles(runners) + } + } + } + + fmt.Printf("\nGinkgo ran %d %s in %s\n", numSuites, pluralizedWord("suite", "suites", numSuites), time.Since(t)) + + if runResult.Passed { + if runResult.HasProgrammaticFocus && strings.TrimSpace(os.Getenv("GINKGO_EDITOR_INTEGRATION")) == "" { + fmt.Printf("Test Suite Passed\n") + fmt.Printf("Detected Programmatic Focus - setting exit status to %d\n", types.GINKGO_FOCUS_EXIT_CODE) + os.Exit(types.GINKGO_FOCUS_EXIT_CODE) + } else { + fmt.Printf("Test Suite Passed\n") + os.Exit(0) + } + } else { + fmt.Printf("Test Suite Failed\n") + os.Exit(1) + } +} + +// Moves all generated profiles to specified directory +func (r *SpecRunner) moveCoverprofiles(runners []*testrunner.TestRunner) { + for _, runner := range runners { + _, filename := filepath.Split(runner.CoverageFile) + err := os.Rename(runner.CoverageFile, filepath.Join(r.getOutputDir(), filename)) + + if err != nil { + fmt.Printf("Unable to move coverprofile %s, %v\n", runner.CoverageFile, err) + return + } + } +} + +// Combines all generated profiles in the specified directory +func (r *SpecRunner) combineCoverprofiles(runners []*testrunner.TestRunner) error { + + path, _ := filepath.Abs(r.getOutputDir()) + if !fileExists(path) { + return fmt.Errorf("Unable to create combined profile, outputdir does not exist: %s", r.getOutputDir()) + } + + fmt.Println("path is " + path) + + combined, err := os.OpenFile( + filepath.Join(path, r.getCoverprofile()), + os.O_WRONLY|os.O_CREATE, + 0666, + ) + + if err != nil { + fmt.Printf("Unable to create combined profile, %v\n", err) + return nil // non-fatal error + } + + modeRegex := regexp.MustCompile(`^mode: .*\n`) + for index, runner := range runners { + contents, err := ioutil.ReadFile(runner.CoverageFile) + + if err != nil { + fmt.Printf("Unable to read coverage file %s to combine, %v\n", runner.CoverageFile, err) + return nil // non-fatal error + } + + // remove the cover mode line from every file + // except the first one + if index > 0 { + contents = modeRegex.ReplaceAll(contents, []byte{}) + } + + _, err = combined.Write(contents) + + // Add a newline to the end of every file if missing. + if err == nil && len(contents) > 0 && contents[len(contents)-1] != '\n' { + _, err = combined.Write([]byte("\n")) + } + + if err != nil { + fmt.Printf("Unable to append to coverprofile, %v\n", err) + return nil // non-fatal error + } + } + + fmt.Println("All profiles combined") + return nil +} + +func (r *SpecRunner) isInCoverageMode() bool { + opts := r.commandFlags.GoOpts + return *opts["cover"].(*bool) || *opts["coverpkg"].(*string) != "" || *opts["covermode"].(*string) != "" +} + +func (r *SpecRunner) getCoverprofile() string { + return *r.commandFlags.GoOpts["coverprofile"].(*string) +} + +func (r *SpecRunner) getOutputDir() string { + return *r.commandFlags.GoOpts["outputdir"].(*string) +} + +func (r *SpecRunner) ComputeSuccinctMode(numSuites int) { + if config.DefaultReporterConfig.Verbose { + config.DefaultReporterConfig.Succinct = false + return + } + + if numSuites == 1 { + return + } + + if numSuites > 1 && !r.commandFlags.wasSet("succinct") { + config.DefaultReporterConfig.Succinct = true + } +} + +func (r *SpecRunner) UpdateSeed() { + if !r.commandFlags.wasSet("seed") { + config.GinkgoConfig.RandomSeed = time.Now().Unix() + } +} + +func (r *SpecRunner) randomizeOrder(runners []*testrunner.TestRunner) []*testrunner.TestRunner { + if !r.commandFlags.RandomizeSuites { + return runners + } + + if len(runners) <= 1 { + return runners + } + + randomizedRunners := make([]*testrunner.TestRunner, len(runners)) + randomizer := rand.New(rand.NewSource(config.GinkgoConfig.RandomSeed)) + permutation := randomizer.Perm(len(runners)) + for i, j := range permutation { + randomizedRunners[i] = runners[j] + } + return randomizedRunners +} + +func orcMessage(iteration int) string { + if iteration < 10 { + return "" + } else if iteration < 30 { + return []string{ + "If at first you succeed...", + "...try, try again.", + "Looking good!", + "Still good...", + "I think your tests are fine....", + "Yep, still passing", + "Oh boy, here I go testin' again!", + "Even the gophers are getting bored", + "Did you try -race?", + "Maybe you should stop now?", + "I'm getting tired...", + "What if I just made you a sandwich?", + "Hit ^C, hit ^C, please hit ^C", + "Make it stop. Please!", + "Come on! Enough is enough!", + "Dave, this conversation can serve no purpose anymore. Goodbye.", + "Just what do you think you're doing, Dave? ", + "I, Sisyphus", + "Insanity: doing the same thing over and over again and expecting different results. -Einstein", + "I guess Einstein never tried to churn butter", + }[iteration-10] + "\n" + } else { + return "No, seriously... you can probably stop now.\n" + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go b/vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go new file mode 100644 index 00000000..e0994fc3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go @@ -0,0 +1,169 @@ +package main + +import ( + "flag" + "runtime" + + "time" + + "github.com/onsi/ginkgo/config" +) + +type RunWatchAndBuildCommandFlags struct { + Recurse bool + SkipPackage string + GoOpts map[string]interface{} + + //for run and watch commands + NumCPU int + NumCompilers int + ParallelStream bool + Notify bool + AfterSuiteHook string + AutoNodes bool + Timeout time.Duration + + //only for run command + KeepGoing bool + UntilItFails bool + RandomizeSuites bool + + //only for watch command + Depth int + WatchRegExp string + + FlagSet *flag.FlagSet +} + +const runMode = 1 +const watchMode = 2 +const buildMode = 3 + +func NewRunCommandFlags(flagSet *flag.FlagSet) *RunWatchAndBuildCommandFlags { + c := &RunWatchAndBuildCommandFlags{ + FlagSet: flagSet, + } + c.flags(runMode) + return c +} + +func NewWatchCommandFlags(flagSet *flag.FlagSet) *RunWatchAndBuildCommandFlags { + c := &RunWatchAndBuildCommandFlags{ + FlagSet: flagSet, + } + c.flags(watchMode) + return c +} + +func NewBuildCommandFlags(flagSet *flag.FlagSet) *RunWatchAndBuildCommandFlags { + c := &RunWatchAndBuildCommandFlags{ + FlagSet: flagSet, + } + c.flags(buildMode) + return c +} + +func (c *RunWatchAndBuildCommandFlags) wasSet(flagName string) bool { + wasSet := false + c.FlagSet.Visit(func(f *flag.Flag) { + if f.Name == flagName { + wasSet = true + } + }) + + return wasSet +} + +func (c *RunWatchAndBuildCommandFlags) computeNodes() { + if c.wasSet("nodes") { + return + } + if c.AutoNodes { + switch n := runtime.NumCPU(); { + case n <= 4: + c.NumCPU = n + default: + c.NumCPU = n - 1 + } + } +} + +func (c *RunWatchAndBuildCommandFlags) stringSlot(slot string) *string { + var opt string + c.GoOpts[slot] = &opt + return &opt +} + +func (c *RunWatchAndBuildCommandFlags) boolSlot(slot string) *bool { + var opt bool + c.GoOpts[slot] = &opt + return &opt +} + +func (c *RunWatchAndBuildCommandFlags) intSlot(slot string) *int { + var opt int + c.GoOpts[slot] = &opt + return &opt +} + +func (c *RunWatchAndBuildCommandFlags) flags(mode int) { + c.GoOpts = make(map[string]interface{}) + + onWindows := (runtime.GOOS == "windows") + + c.FlagSet.BoolVar(&(c.Recurse), "r", false, "Find and run test suites under the current directory recursively.") + c.FlagSet.BoolVar(c.boolSlot("race"), "race", false, "Run tests with race detection enabled.") + c.FlagSet.BoolVar(c.boolSlot("cover"), "cover", false, "Run tests with coverage analysis, will generate coverage profiles with the package name in the current directory.") + c.FlagSet.StringVar(c.stringSlot("coverpkg"), "coverpkg", "", "Run tests with coverage on the given external modules.") + c.FlagSet.StringVar(&(c.SkipPackage), "skipPackage", "", "A comma-separated list of package names to be skipped. If any part of the package's path matches, that package is ignored.") + c.FlagSet.StringVar(c.stringSlot("tags"), "tags", "", "A list of build tags to consider satisfied during the build.") + c.FlagSet.StringVar(c.stringSlot("gcflags"), "gcflags", "", "Arguments to pass on each go tool compile invocation.") + c.FlagSet.StringVar(c.stringSlot("covermode"), "covermode", "", "Set the mode for coverage analysis.") + c.FlagSet.BoolVar(c.boolSlot("a"), "a", false, "Force rebuilding of packages that are already up-to-date.") + c.FlagSet.BoolVar(c.boolSlot("n"), "n", false, "Have `go test` print the commands but do not run them.") + c.FlagSet.BoolVar(c.boolSlot("msan"), "msan", false, "Enable interoperation with memory sanitizer.") + c.FlagSet.BoolVar(c.boolSlot("x"), "x", false, "Have `go test` print the commands.") + c.FlagSet.BoolVar(c.boolSlot("work"), "work", false, "Print the name of the temporary work directory and do not delete it when exiting.") + c.FlagSet.StringVar(c.stringSlot("asmflags"), "asmflags", "", "Arguments to pass on each go tool asm invocation.") + c.FlagSet.StringVar(c.stringSlot("buildmode"), "buildmode", "", "Build mode to use. See 'go help buildmode' for more.") + c.FlagSet.StringVar(c.stringSlot("mod"), "mod", "", "Go module control. See 'go help modules' for more.") + c.FlagSet.StringVar(c.stringSlot("compiler"), "compiler", "", "Name of compiler to use, as in runtime.Compiler (gccgo or gc).") + c.FlagSet.StringVar(c.stringSlot("gccgoflags"), "gccgoflags", "", "Arguments to pass on each gccgo compiler/linker invocation.") + c.FlagSet.StringVar(c.stringSlot("installsuffix"), "installsuffix", "", "A suffix to use in the name of the package installation directory.") + c.FlagSet.StringVar(c.stringSlot("ldflags"), "ldflags", "", "Arguments to pass on each go tool link invocation.") + c.FlagSet.BoolVar(c.boolSlot("linkshared"), "linkshared", false, "Link against shared libraries previously created with -buildmode=shared.") + c.FlagSet.StringVar(c.stringSlot("pkgdir"), "pkgdir", "", "install and load all packages from the given dir instead of the usual locations.") + c.FlagSet.StringVar(c.stringSlot("toolexec"), "toolexec", "", "a program to use to invoke toolchain programs like vet and asm.") + c.FlagSet.IntVar(c.intSlot("blockprofilerate"), "blockprofilerate", 1, "Control the detail provided in goroutine blocking profiles by calling runtime.SetBlockProfileRate with the given value.") + c.FlagSet.StringVar(c.stringSlot("coverprofile"), "coverprofile", "", "Write a coverage profile to the specified file after all tests have passed.") + c.FlagSet.StringVar(c.stringSlot("cpuprofile"), "cpuprofile", "", "Write a CPU profile to the specified file before exiting.") + c.FlagSet.StringVar(c.stringSlot("memprofile"), "memprofile", "", "Write a memory profile to the specified file after all tests have passed.") + c.FlagSet.IntVar(c.intSlot("memprofilerate"), "memprofilerate", 0, "Enable more precise (and expensive) memory profiles by setting runtime.MemProfileRate.") + c.FlagSet.StringVar(c.stringSlot("outputdir"), "outputdir", "", "Place output files from profiling in the specified directory.") + c.FlagSet.BoolVar(c.boolSlot("requireSuite"), "requireSuite", false, "Fail if there are ginkgo tests in a directory but no test suite (missing RunSpecs)") + c.FlagSet.StringVar(c.stringSlot("vet"), "vet", "", "Configure the invocation of 'go vet' to use the comma-separated list of vet checks. If list is 'off', 'go test' does not run 'go vet' at all.") + + if mode == runMode || mode == watchMode { + config.Flags(c.FlagSet, "", false) + c.FlagSet.IntVar(&(c.NumCPU), "nodes", 1, "The number of parallel test nodes to run") + c.FlagSet.IntVar(&(c.NumCompilers), "compilers", 0, "The number of concurrent compilations to run (0 will autodetect)") + c.FlagSet.BoolVar(&(c.AutoNodes), "p", false, "Run in parallel with auto-detected number of nodes") + c.FlagSet.BoolVar(&(c.ParallelStream), "stream", onWindows, "stream parallel test output in real time: less coherent, but useful for debugging") + if !onWindows { + c.FlagSet.BoolVar(&(c.Notify), "notify", false, "Send desktop notifications when a test run completes") + } + c.FlagSet.StringVar(&(c.AfterSuiteHook), "afterSuiteHook", "", "Run a command when a suite test run completes") + c.FlagSet.DurationVar(&(c.Timeout), "timeout", 24*time.Hour, "Suite fails if it does not complete within the specified timeout") + } + + if mode == runMode { + c.FlagSet.BoolVar(&(c.KeepGoing), "keepGoing", false, "When true, failures from earlier test suites do not prevent later test suites from running") + c.FlagSet.BoolVar(&(c.UntilItFails), "untilItFails", false, "When true, Ginkgo will keep rerunning tests until a failure occurs") + c.FlagSet.BoolVar(&(c.RandomizeSuites), "randomizeSuites", false, "When true, Ginkgo will randomize the order in which test suites run") + } + + if mode == watchMode { + c.FlagSet.IntVar(&(c.Depth), "depth", 1, "Ginkgo will watch dependencies down to this depth in the dependency tree") + c.FlagSet.StringVar(&(c.WatchRegExp), "watchRegExp", `\.go$`, "Files matching this regular expression will be watched for changes") + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go b/vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go new file mode 100644 index 00000000..ab746d7e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go @@ -0,0 +1,173 @@ +package main + +import ( + "fmt" + "runtime" + "sync" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/ginkgo/interrupthandler" + "github.com/onsi/ginkgo/ginkgo/testrunner" + "github.com/onsi/ginkgo/ginkgo/testsuite" + colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" +) + +type compilationInput struct { + runner *testrunner.TestRunner + result chan compilationOutput +} + +type compilationOutput struct { + runner *testrunner.TestRunner + err error +} + +type SuiteRunner struct { + notifier *Notifier + interruptHandler *interrupthandler.InterruptHandler +} + +func NewSuiteRunner(notifier *Notifier, interruptHandler *interrupthandler.InterruptHandler) *SuiteRunner { + return &SuiteRunner{ + notifier: notifier, + interruptHandler: interruptHandler, + } +} + +func (r *SuiteRunner) compileInParallel(runners []*testrunner.TestRunner, numCompilers int, willCompile func(suite testsuite.TestSuite)) chan compilationOutput { + //we return this to the consumer, it will return each runner in order as it compiles + compilationOutputs := make(chan compilationOutput, len(runners)) + + //an array of channels - the nth runner's compilation output is sent to the nth channel in this array + //we read from these channels in order to ensure we run the suites in order + orderedCompilationOutputs := []chan compilationOutput{} + for range runners { + orderedCompilationOutputs = append(orderedCompilationOutputs, make(chan compilationOutput, 1)) + } + + //we're going to spin up numCompilers compilers - they're going to run concurrently and will consume this channel + //we prefill the channel then close it, this ensures we compile things in the correct order + workPool := make(chan compilationInput, len(runners)) + for i, runner := range runners { + workPool <- compilationInput{runner, orderedCompilationOutputs[i]} + } + close(workPool) + + //pick a reasonable numCompilers + if numCompilers == 0 { + numCompilers = runtime.NumCPU() + } + + //a WaitGroup to help us wait for all compilers to shut down + wg := &sync.WaitGroup{} + wg.Add(numCompilers) + + //spin up the concurrent compilers + for i := 0; i < numCompilers; i++ { + go func() { + defer wg.Done() + for input := range workPool { + if r.interruptHandler.WasInterrupted() { + return + } + + if willCompile != nil { + willCompile(input.runner.Suite) + } + + //We retry because Go sometimes steps on itself when multiple compiles happen in parallel. This is ugly, but should help resolve flakiness... + var err error + retries := 0 + for retries <= 5 { + if r.interruptHandler.WasInterrupted() { + return + } + if err = input.runner.Compile(); err == nil { + break + } + retries++ + } + + input.result <- compilationOutput{input.runner, err} + } + }() + } + + //read from the compilation output channels *in order* and send them to the caller + //close the compilationOutputs channel to tell the caller we're done + go func() { + defer close(compilationOutputs) + for _, orderedCompilationOutput := range orderedCompilationOutputs { + select { + case compilationOutput := <-orderedCompilationOutput: + compilationOutputs <- compilationOutput + case <-r.interruptHandler.C: + //interrupt detected, wait for the compilers to shut down then bail + //this ensure we clean up after ourselves as we don't leave any compilation processes running + wg.Wait() + return + } + } + }() + + return compilationOutputs +} + +func (r *SuiteRunner) RunSuites(runners []*testrunner.TestRunner, numCompilers int, keepGoing bool, willCompile func(suite testsuite.TestSuite)) (testrunner.RunResult, int) { + runResult := testrunner.PassingRunResult() + + compilationOutputs := r.compileInParallel(runners, numCompilers, willCompile) + + numSuitesThatRan := 0 + suitesThatFailed := []testsuite.TestSuite{} + for compilationOutput := range compilationOutputs { + if compilationOutput.err != nil { + fmt.Print(compilationOutput.err.Error()) + } + numSuitesThatRan++ + suiteRunResult := testrunner.FailingRunResult() + if compilationOutput.err == nil { + suiteRunResult = compilationOutput.runner.Run() + } + r.notifier.SendSuiteCompletionNotification(compilationOutput.runner.Suite, suiteRunResult.Passed) + r.notifier.RunCommand(compilationOutput.runner.Suite, suiteRunResult.Passed) + runResult = runResult.Merge(suiteRunResult) + if !suiteRunResult.Passed { + suitesThatFailed = append(suitesThatFailed, compilationOutput.runner.Suite) + if !keepGoing { + break + } + } + if numSuitesThatRan < len(runners) && !config.DefaultReporterConfig.Succinct { + fmt.Println("") + } + } + + if keepGoing && !runResult.Passed { + r.listFailedSuites(suitesThatFailed) + } + + return runResult, numSuitesThatRan +} + +func (r *SuiteRunner) listFailedSuites(suitesThatFailed []testsuite.TestSuite) { + fmt.Println("") + fmt.Println("There were failures detected in the following suites:") + + maxPackageNameLength := 0 + for _, suite := range suitesThatFailed { + if len(suite.PackageName) > maxPackageNameLength { + maxPackageNameLength = len(suite.PackageName) + } + } + + packageNameFormatter := fmt.Sprintf("%%%ds", maxPackageNameLength) + + for _, suite := range suitesThatFailed { + if config.DefaultReporterConfig.NoColor { + fmt.Printf("\t"+packageNameFormatter+" %s\n", suite.PackageName, suite.Path) + } else { + fmt.Fprintf(colorable.NewColorableStdout(), "\t%s"+packageNameFormatter+"%s %s%s%s\n", redColor, suite.PackageName, defaultStyle, lightGrayColor, suite.Path, defaultStyle) + } + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go new file mode 100644 index 00000000..3b1a238c --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go @@ -0,0 +1,7 @@ +// +build go1.10 + +package testrunner + +var ( + buildArgs = []string{"test", "-c"} +) diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go new file mode 100644 index 00000000..14d70dbc --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go @@ -0,0 +1,7 @@ +// +build !go1.10 + +package testrunner + +var ( + buildArgs = []string{"test", "-c", "-i"} +) diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go new file mode 100644 index 00000000..a73a6e37 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go @@ -0,0 +1,52 @@ +package testrunner + +import ( + "bytes" + "fmt" + "io" + "log" + "strings" + "sync" +) + +type logWriter struct { + buffer *bytes.Buffer + lock *sync.Mutex + log *log.Logger +} + +func newLogWriter(target io.Writer, node int) *logWriter { + return &logWriter{ + buffer: &bytes.Buffer{}, + lock: &sync.Mutex{}, + log: log.New(target, fmt.Sprintf("[%d] ", node), 0), + } +} + +func (w *logWriter) Write(data []byte) (n int, err error) { + w.lock.Lock() + defer w.lock.Unlock() + + w.buffer.Write(data) + contents := w.buffer.String() + + lines := strings.Split(contents, "\n") + for _, line := range lines[0 : len(lines)-1] { + w.log.Println(line) + } + + w.buffer.Reset() + w.buffer.Write([]byte(lines[len(lines)-1])) + return len(data), nil +} + +func (w *logWriter) Close() error { + w.lock.Lock() + defer w.lock.Unlock() + + if w.buffer.Len() > 0 { + w.log.Println(w.buffer.String()) + } + + return nil +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go new file mode 100644 index 00000000..5d472acb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go @@ -0,0 +1,27 @@ +package testrunner + +type RunResult struct { + Passed bool + HasProgrammaticFocus bool +} + +func PassingRunResult() RunResult { + return RunResult{ + Passed: true, + HasProgrammaticFocus: false, + } +} + +func FailingRunResult() RunResult { + return RunResult{ + Passed: false, + HasProgrammaticFocus: false, + } +} + +func (r RunResult) Merge(o RunResult) RunResult { + return RunResult{ + Passed: r.Passed && o.Passed, + HasProgrammaticFocus: r.HasProgrammaticFocus || o.HasProgrammaticFocus, + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go new file mode 100644 index 00000000..66c0f06f --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go @@ -0,0 +1,554 @@ +package testrunner + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "syscall" + "time" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/ginkgo/testsuite" + "github.com/onsi/ginkgo/internal/remote" + "github.com/onsi/ginkgo/reporters/stenographer" + colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" + "github.com/onsi/ginkgo/types" +) + +type TestRunner struct { + Suite testsuite.TestSuite + + compiled bool + compilationTargetPath string + + numCPU int + parallelStream bool + timeout time.Duration + goOpts map[string]interface{} + additionalArgs []string + stderr *bytes.Buffer + + CoverageFile string +} + +func New(suite testsuite.TestSuite, numCPU int, parallelStream bool, timeout time.Duration, goOpts map[string]interface{}, additionalArgs []string) *TestRunner { + runner := &TestRunner{ + Suite: suite, + numCPU: numCPU, + parallelStream: parallelStream, + goOpts: goOpts, + additionalArgs: additionalArgs, + timeout: timeout, + stderr: new(bytes.Buffer), + } + + if !suite.Precompiled { + runner.compilationTargetPath, _ = filepath.Abs(filepath.Join(suite.Path, suite.PackageName+".test")) + } + + return runner +} + +func (t *TestRunner) Compile() error { + return t.CompileTo(t.compilationTargetPath) +} + +func (t *TestRunner) BuildArgs(path string) []string { + args := make([]string, len(buildArgs), len(buildArgs)+3) + copy(args, buildArgs) + args = append(args, "-o", path, t.Suite.Path) + + if t.getCoverMode() != "" { + args = append(args, "-cover", fmt.Sprintf("-covermode=%s", t.getCoverMode())) + } else { + if t.shouldCover() || t.getCoverPackage() != "" { + args = append(args, "-cover", "-covermode=atomic") + } + } + + boolOpts := []string{ + "a", + "n", + "msan", + "race", + "x", + "work", + "linkshared", + } + + for _, opt := range boolOpts { + if s, found := t.goOpts[opt].(*bool); found && *s { + args = append(args, fmt.Sprintf("-%s", opt)) + } + } + + intOpts := []string{ + "memprofilerate", + "blockprofilerate", + } + + for _, opt := range intOpts { + if s, found := t.goOpts[opt].(*int); found { + args = append(args, fmt.Sprintf("-%s=%d", opt, *s)) + } + } + + stringOpts := []string{ + "asmflags", + "buildmode", + "compiler", + "gccgoflags", + "installsuffix", + "ldflags", + "pkgdir", + "toolexec", + "coverprofile", + "cpuprofile", + "memprofile", + "outputdir", + "coverpkg", + "tags", + "gcflags", + "vet", + "mod", + } + + for _, opt := range stringOpts { + if s, found := t.goOpts[opt].(*string); found && *s != "" { + args = append(args, fmt.Sprintf("-%s=%s", opt, *s)) + } + } + return args +} + +func (t *TestRunner) CompileTo(path string) error { + if t.compiled { + return nil + } + + if t.Suite.Precompiled { + return nil + } + + args := t.BuildArgs(path) + cmd := exec.Command("go", args...) + + output, err := cmd.CombinedOutput() + + if err != nil { + if len(output) > 0 { + return fmt.Errorf("Failed to compile %s:\n\n%s", t.Suite.PackageName, output) + } + return fmt.Errorf("Failed to compile %s", t.Suite.PackageName) + } + + if len(output) > 0 { + fmt.Println(string(output)) + } + + if !fileExists(path) { + compiledFile := t.Suite.PackageName + ".test" + if fileExists(compiledFile) { + // seems like we are on an old go version that does not support the -o flag on go test + // move the compiled test file to the desired location by hand + err = os.Rename(compiledFile, path) + if err != nil { + // We cannot move the file, perhaps because the source and destination + // are on different partitions. We can copy the file, however. + err = copyFile(compiledFile, path) + if err != nil { + return fmt.Errorf("Failed to copy compiled file: %s", err) + } + } + } else { + return fmt.Errorf("Failed to compile %s: output file %q could not be found", t.Suite.PackageName, path) + } + } + + t.compiled = true + + return nil +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil || !os.IsNotExist(err) +} + +// copyFile copies the contents of the file named src to the file named +// by dst. The file will be created if it does not already exist. If the +// destination file exists, all it's contents will be replaced by the contents +// of the source file. +func copyFile(src, dst string) error { + srcInfo, err := os.Stat(src) + if err != nil { + return err + } + mode := srcInfo.Mode() + + in, err := os.Open(src) + if err != nil { + return err + } + + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + + defer func() { + closeErr := out.Close() + if err == nil { + err = closeErr + } + }() + + _, err = io.Copy(out, in) + if err != nil { + return err + } + + err = out.Sync() + if err != nil { + return err + } + + return out.Chmod(mode) +} + +func (t *TestRunner) Run() RunResult { + if t.Suite.IsGinkgo { + if t.numCPU > 1 { + if t.parallelStream { + return t.runAndStreamParallelGinkgoSuite() + } else { + return t.runParallelGinkgoSuite() + } + } else { + return t.runSerialGinkgoSuite() + } + } else { + return t.runGoTestSuite() + } +} + +func (t *TestRunner) CleanUp() { + if t.Suite.Precompiled { + return + } + os.Remove(t.compilationTargetPath) +} + +func (t *TestRunner) runSerialGinkgoSuite() RunResult { + ginkgoArgs := config.BuildFlagArgs("ginkgo", config.GinkgoConfig, config.DefaultReporterConfig) + return t.run(t.cmd(ginkgoArgs, os.Stdout, 1), nil) +} + +func (t *TestRunner) runGoTestSuite() RunResult { + return t.run(t.cmd([]string{"-test.v"}, os.Stdout, 1), nil) +} + +func (t *TestRunner) runAndStreamParallelGinkgoSuite() RunResult { + completions := make(chan RunResult) + writers := make([]*logWriter, t.numCPU) + + server, err := remote.NewServer(t.numCPU) + if err != nil { + panic("Failed to start parallel spec server") + } + + server.Start() + defer server.Close() + + for cpu := 0; cpu < t.numCPU; cpu++ { + config.GinkgoConfig.ParallelNode = cpu + 1 + config.GinkgoConfig.ParallelTotal = t.numCPU + config.GinkgoConfig.SyncHost = server.Address() + + ginkgoArgs := config.BuildFlagArgs("ginkgo", config.GinkgoConfig, config.DefaultReporterConfig) + + writers[cpu] = newLogWriter(os.Stdout, cpu+1) + + cmd := t.cmd(ginkgoArgs, writers[cpu], cpu+1) + + server.RegisterAlive(cpu+1, func() bool { + if cmd.ProcessState == nil { + return true + } + return !cmd.ProcessState.Exited() + }) + + go t.run(cmd, completions) + } + + res := PassingRunResult() + + for cpu := 0; cpu < t.numCPU; cpu++ { + res = res.Merge(<-completions) + } + + for _, writer := range writers { + writer.Close() + } + + os.Stdout.Sync() + + if t.shouldCombineCoverprofiles() { + t.combineCoverprofiles() + } + + return res +} + +func (t *TestRunner) runParallelGinkgoSuite() RunResult { + result := make(chan bool) + completions := make(chan RunResult) + writers := make([]*logWriter, t.numCPU) + reports := make([]*bytes.Buffer, t.numCPU) + + stenographer := stenographer.New(!config.DefaultReporterConfig.NoColor, config.GinkgoConfig.FlakeAttempts > 1, colorable.NewColorableStdout()) + aggregator := remote.NewAggregator(t.numCPU, result, config.DefaultReporterConfig, stenographer) + + server, err := remote.NewServer(t.numCPU) + if err != nil { + panic("Failed to start parallel spec server") + } + server.RegisterReporters(aggregator) + server.Start() + defer server.Close() + + for cpu := 0; cpu < t.numCPU; cpu++ { + config.GinkgoConfig.ParallelNode = cpu + 1 + config.GinkgoConfig.ParallelTotal = t.numCPU + config.GinkgoConfig.SyncHost = server.Address() + config.GinkgoConfig.StreamHost = server.Address() + + ginkgoArgs := config.BuildFlagArgs("ginkgo", config.GinkgoConfig, config.DefaultReporterConfig) + + reports[cpu] = &bytes.Buffer{} + writers[cpu] = newLogWriter(reports[cpu], cpu+1) + + cmd := t.cmd(ginkgoArgs, writers[cpu], cpu+1) + + server.RegisterAlive(cpu+1, func() bool { + if cmd.ProcessState == nil { + return true + } + return !cmd.ProcessState.Exited() + }) + + go t.run(cmd, completions) + } + + res := PassingRunResult() + + for cpu := 0; cpu < t.numCPU; cpu++ { + res = res.Merge(<-completions) + } + + //all test processes are done, at this point + //we should be able to wait for the aggregator to tell us that it's done + + select { + case <-result: + fmt.Println("") + case <-time.After(time.Second): + //the aggregator never got back to us! something must have gone wrong + fmt.Println(` + ------------------------------------------------------------------- + | | + | Ginkgo timed out waiting for all parallel nodes to report back! | + | | + -------------------------------------------------------------------`) + fmt.Println("\n", t.Suite.PackageName, "timed out. path:", t.Suite.Path) + os.Stdout.Sync() + + for _, writer := range writers { + writer.Close() + } + + for _, report := range reports { + fmt.Print(report.String()) + } + + os.Stdout.Sync() + } + + if t.shouldCombineCoverprofiles() { + t.combineCoverprofiles() + } + + return res +} + +const CoverProfileSuffix = ".coverprofile" + +func (t *TestRunner) cmd(ginkgoArgs []string, stream io.Writer, node int) *exec.Cmd { + args := []string{"--test.timeout=" + t.timeout.String()} + + coverProfile := t.getCoverProfile() + + if t.shouldCombineCoverprofiles() { + + testCoverProfile := "--test.coverprofile=" + + coverageFile := "" + // Set default name for coverage results + if coverProfile == "" { + coverageFile = t.Suite.PackageName + CoverProfileSuffix + } else { + coverageFile = coverProfile + } + + testCoverProfile += coverageFile + + t.CoverageFile = filepath.Join(t.Suite.Path, coverageFile) + + if t.numCPU > 1 { + testCoverProfile = fmt.Sprintf("%s.%d", testCoverProfile, node) + } + args = append(args, testCoverProfile) + } + + args = append(args, ginkgoArgs...) + args = append(args, t.additionalArgs...) + + path := t.compilationTargetPath + if t.Suite.Precompiled { + path, _ = filepath.Abs(filepath.Join(t.Suite.Path, fmt.Sprintf("%s.test", t.Suite.PackageName))) + } + + cmd := exec.Command(path, args...) + + cmd.Dir = t.Suite.Path + cmd.Stderr = io.MultiWriter(stream, t.stderr) + cmd.Stdout = stream + + return cmd +} + +func (t *TestRunner) shouldCover() bool { + return *t.goOpts["cover"].(*bool) +} + +func (t *TestRunner) shouldRequireSuite() bool { + return *t.goOpts["requireSuite"].(*bool) +} + +func (t *TestRunner) getCoverProfile() string { + return *t.goOpts["coverprofile"].(*string) +} + +func (t *TestRunner) getCoverPackage() string { + return *t.goOpts["coverpkg"].(*string) +} + +func (t *TestRunner) getCoverMode() string { + return *t.goOpts["covermode"].(*string) +} + +func (t *TestRunner) shouldCombineCoverprofiles() bool { + return t.shouldCover() || t.getCoverPackage() != "" || t.getCoverMode() != "" +} + +func (t *TestRunner) run(cmd *exec.Cmd, completions chan RunResult) RunResult { + var res RunResult + + defer func() { + if completions != nil { + completions <- res + } + }() + + err := cmd.Start() + if err != nil { + fmt.Printf("Failed to run test suite!\n\t%s", err.Error()) + return res + } + + cmd.Wait() + + exitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() + res.Passed = (exitStatus == 0) || (exitStatus == types.GINKGO_FOCUS_EXIT_CODE) + res.HasProgrammaticFocus = (exitStatus == types.GINKGO_FOCUS_EXIT_CODE) + + if strings.Contains(t.stderr.String(), "warning: no tests to run") { + if t.shouldRequireSuite() { + res.Passed = false + } + fmt.Fprintf(os.Stderr, `Found no test suites, did you forget to run "ginkgo bootstrap"?`) + } + + return res +} + +func (t *TestRunner) combineCoverprofiles() { + profiles := []string{} + + coverProfile := t.getCoverProfile() + + for cpu := 1; cpu <= t.numCPU; cpu++ { + var coverFile string + if coverProfile == "" { + coverFile = fmt.Sprintf("%s%s.%d", t.Suite.PackageName, CoverProfileSuffix, cpu) + } else { + coverFile = fmt.Sprintf("%s.%d", coverProfile, cpu) + } + + coverFile = filepath.Join(t.Suite.Path, coverFile) + coverProfile, err := ioutil.ReadFile(coverFile) + os.Remove(coverFile) + + if err == nil { + profiles = append(profiles, string(coverProfile)) + } + } + + if len(profiles) != t.numCPU { + return + } + + lines := map[string]int{} + lineOrder := []string{} + for i, coverProfile := range profiles { + for _, line := range strings.Split(coverProfile, "\n")[1:] { + if len(line) == 0 { + continue + } + components := strings.Split(line, " ") + count, _ := strconv.Atoi(components[len(components)-1]) + prefix := strings.Join(components[0:len(components)-1], " ") + lines[prefix] += count + if i == 0 { + lineOrder = append(lineOrder, prefix) + } + } + } + + output := []string{"mode: atomic"} + for _, line := range lineOrder { + output = append(output, fmt.Sprintf("%s %d", line, lines[line])) + } + finalOutput := strings.Join(output, "\n") + + finalFilename := "" + + if coverProfile != "" { + finalFilename = coverProfile + } else { + finalFilename = fmt.Sprintf("%s%s", t.Suite.PackageName, CoverProfileSuffix) + } + + coverageFilepath := filepath.Join(t.Suite.Path, finalFilename) + ioutil.WriteFile(coverageFilepath, []byte(finalOutput), 0666) + + t.CoverageFile = coverageFilepath +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go new file mode 100644 index 00000000..9de8c2bb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go @@ -0,0 +1,115 @@ +package testsuite + +import ( + "errors" + "io/ioutil" + "os" + "path/filepath" + "regexp" + "strings" +) + +type TestSuite struct { + Path string + PackageName string + IsGinkgo bool + Precompiled bool +} + +func PrecompiledTestSuite(path string) (TestSuite, error) { + info, err := os.Stat(path) + if err != nil { + return TestSuite{}, err + } + + if info.IsDir() { + return TestSuite{}, errors.New("this is a directory, not a file") + } + + if filepath.Ext(path) != ".test" { + return TestSuite{}, errors.New("this is not a .test binary") + } + + if info.Mode()&0111 == 0 { + return TestSuite{}, errors.New("this is not executable") + } + + dir := relPath(filepath.Dir(path)) + packageName := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path)) + + return TestSuite{ + Path: dir, + PackageName: packageName, + IsGinkgo: true, + Precompiled: true, + }, nil +} + +func SuitesInDir(dir string, recurse bool) []TestSuite { + suites := []TestSuite{} + + if vendorExperimentCheck(dir) { + return suites + } + + files, _ := ioutil.ReadDir(dir) + re := regexp.MustCompile(`^[^._].*_test\.go$`) + for _, file := range files { + if !file.IsDir() && re.Match([]byte(file.Name())) { + suites = append(suites, New(dir, files)) + break + } + } + + if recurse { + re = regexp.MustCompile(`^[._]`) + for _, file := range files { + if file.IsDir() && !re.Match([]byte(file.Name())) { + suites = append(suites, SuitesInDir(dir+"/"+file.Name(), recurse)...) + } + } + } + + return suites +} + +func relPath(dir string) string { + dir, _ = filepath.Abs(dir) + cwd, _ := os.Getwd() + dir, _ = filepath.Rel(cwd, filepath.Clean(dir)) + + if string(dir[0]) != "." { + dir = "." + string(filepath.Separator) + dir + } + + return dir +} + +func New(dir string, files []os.FileInfo) TestSuite { + return TestSuite{ + Path: relPath(dir), + PackageName: packageNameForSuite(dir), + IsGinkgo: filesHaveGinkgoSuite(dir, files), + } +} + +func packageNameForSuite(dir string) string { + path, _ := filepath.Abs(dir) + return filepath.Base(path) +} + +func filesHaveGinkgoSuite(dir string, files []os.FileInfo) bool { + reTestFile := regexp.MustCompile(`_test\.go$`) + reGinkgo := regexp.MustCompile(`package ginkgo|\/ginkgo"`) + + for _, file := range files { + if !file.IsDir() && reTestFile.Match([]byte(file.Name())) { + contents, _ := ioutil.ReadFile(dir + "/" + file.Name()) + if reGinkgo.Match(contents) { + return true + } + } + } + + return false +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go new file mode 100644 index 00000000..75f827a1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go @@ -0,0 +1,16 @@ +// +build !go1.6 + +package testsuite + +import ( + "os" + "path" +) + +// "This change will only be enabled if the go command is run with +// GO15VENDOREXPERIMENT=1 in its environment." +// c.f. the vendor-experiment proposal https://goo.gl/2ucMeC +func vendorExperimentCheck(dir string) bool { + vendorExperiment := os.Getenv("GO15VENDOREXPERIMENT") + return vendorExperiment == "1" && path.Base(dir) == "vendor" +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go new file mode 100644 index 00000000..596e5e5c --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go @@ -0,0 +1,15 @@ +// +build go1.6 + +package testsuite + +import ( + "os" + "path" +) + +// in 1.6 the vendor directory became the default go behaviour, so now +// check if its disabled. +func vendorExperimentCheck(dir string) bool { + vendorExperiment := os.Getenv("GO15VENDOREXPERIMENT") + return vendorExperiment != "0" && path.Base(dir) == "vendor" +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/unfocus_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/unfocus_command.go new file mode 100644 index 00000000..d9dfb6e4 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/unfocus_command.go @@ -0,0 +1,180 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "sync" +) + +func BuildUnfocusCommand() *Command { + return &Command{ + Name: "unfocus", + AltName: "blur", + FlagSet: flag.NewFlagSet("unfocus", flag.ExitOnError), + UsageCommand: "ginkgo unfocus (or ginkgo blur)", + Usage: []string{ + "Recursively unfocuses any focused tests under the current directory", + }, + Command: unfocusSpecs, + } +} + +func unfocusSpecs([]string, []string) { + fmt.Println("Scanning for focus...") + + goFiles := make(chan string) + go func() { + unfocusDir(goFiles, ".") + close(goFiles) + }() + + const workers = 10 + wg := sync.WaitGroup{} + wg.Add(workers) + + for i := 0; i < workers; i++ { + go func() { + for path := range goFiles { + unfocusFile(path) + } + wg.Done() + }() + } + + wg.Wait() +} + +func unfocusDir(goFiles chan string, path string) { + files, err := ioutil.ReadDir(path) + if err != nil { + fmt.Println(err.Error()) + return + } + + for _, f := range files { + switch { + case f.IsDir() && shouldProcessDir(f.Name()): + unfocusDir(goFiles, filepath.Join(path, f.Name())) + case !f.IsDir() && shouldProcessFile(f.Name()): + goFiles <- filepath.Join(path, f.Name()) + } + } +} + +func shouldProcessDir(basename string) bool { + return basename != "vendor" && !strings.HasPrefix(basename, ".") +} + +func shouldProcessFile(basename string) bool { + return strings.HasSuffix(basename, ".go") +} + +func unfocusFile(path string) { + data, err := ioutil.ReadFile(path) + if err != nil { + fmt.Printf("error reading file '%s': %s\n", path, err.Error()) + return + } + + ast, err := parser.ParseFile(token.NewFileSet(), path, bytes.NewReader(data), 0) + if err != nil { + fmt.Printf("error parsing file '%s': %s\n", path, err.Error()) + return + } + + eliminations := scanForFocus(ast) + if len(eliminations) == 0 { + return + } + + fmt.Printf("...updating %s\n", path) + backup, err := writeBackup(path, data) + if err != nil { + fmt.Printf("error creating backup file: %s\n", err.Error()) + return + } + + if err := updateFile(path, data, eliminations); err != nil { + fmt.Printf("error writing file '%s': %s\n", path, err.Error()) + return + } + + os.Remove(backup) +} + +func writeBackup(path string, data []byte) (string, error) { + t, err := ioutil.TempFile(filepath.Dir(path), filepath.Base(path)) + + if err != nil { + return "", fmt.Errorf("error creating temporary file: %w", err) + } + defer t.Close() + + if _, err := io.Copy(t, bytes.NewReader(data)); err != nil { + return "", fmt.Errorf("error writing to temporary file: %w", err) + } + + return t.Name(), nil +} + +func updateFile(path string, data []byte, eliminations []int64) error { + to, err := os.Create(path) + if err != nil { + return fmt.Errorf("error opening file for writing '%s': %w\n", path, err) + } + defer to.Close() + + from := bytes.NewReader(data) + var cursor int64 + for _, byteToEliminate := range eliminations { + if _, err := io.CopyN(to, from, byteToEliminate-cursor); err != nil { + return fmt.Errorf("error copying data: %w", err) + } + + cursor = byteToEliminate + 1 + + if _, err := from.Seek(1, io.SeekCurrent); err != nil { + return fmt.Errorf("error seeking to position in buffer: %w", err) + } + } + + if _, err := io.Copy(to, from); err != nil { + return fmt.Errorf("error copying end data: %w", err) + } + + return nil +} + +func scanForFocus(file *ast.File) (eliminations []int64) { + ast.Inspect(file, func(n ast.Node) bool { + if c, ok := n.(*ast.CallExpr); ok { + if i, ok := c.Fun.(*ast.Ident); ok { + if isFocus(i.Name) { + eliminations = append(eliminations, int64(i.Pos()-file.Pos())) + } + } + } + + return true + }) + + return eliminations +} + +func isFocus(name string) bool { + switch name { + case "FDescribe", "FContext", "FIt", "FMeasure", "FDescribeTable", "FEntry", "FSpecify", "FWhen": + return true + default: + return false + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/version_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/version_command.go new file mode 100644 index 00000000..f586908e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/version_command.go @@ -0,0 +1,24 @@ +package main + +import ( + "flag" + "fmt" + + "github.com/onsi/ginkgo/config" +) + +func BuildVersionCommand() *Command { + return &Command{ + Name: "version", + FlagSet: flag.NewFlagSet("version", flag.ExitOnError), + UsageCommand: "ginkgo version", + Usage: []string{ + "Print Ginkgo's version", + }, + Command: printVersion, + } +} + +func printVersion([]string, []string) { + fmt.Printf("Ginkgo Version %s\n", config.VERSION) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta.go new file mode 100644 index 00000000..6c485c5b --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta.go @@ -0,0 +1,22 @@ +package watch + +import "sort" + +type Delta struct { + ModifiedPackages []string + + NewSuites []*Suite + RemovedSuites []*Suite + modifiedSuites []*Suite +} + +type DescendingByDelta []*Suite + +func (a DescendingByDelta) Len() int { return len(a) } +func (a DescendingByDelta) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a DescendingByDelta) Less(i, j int) bool { return a[i].Delta() > a[j].Delta() } + +func (d Delta) ModifiedSuites() []*Suite { + sort.Sort(DescendingByDelta(d.modifiedSuites)) + return d.modifiedSuites +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta_tracker.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta_tracker.go new file mode 100644 index 00000000..a628303d --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta_tracker.go @@ -0,0 +1,75 @@ +package watch + +import ( + "fmt" + + "regexp" + + "github.com/onsi/ginkgo/ginkgo/testsuite" +) + +type SuiteErrors map[testsuite.TestSuite]error + +type DeltaTracker struct { + maxDepth int + watchRegExp *regexp.Regexp + suites map[string]*Suite + packageHashes *PackageHashes +} + +func NewDeltaTracker(maxDepth int, watchRegExp *regexp.Regexp) *DeltaTracker { + return &DeltaTracker{ + maxDepth: maxDepth, + watchRegExp: watchRegExp, + packageHashes: NewPackageHashes(watchRegExp), + suites: map[string]*Suite{}, + } +} + +func (d *DeltaTracker) Delta(suites []testsuite.TestSuite) (delta Delta, errors SuiteErrors) { + errors = SuiteErrors{} + delta.ModifiedPackages = d.packageHashes.CheckForChanges() + + providedSuitePaths := map[string]bool{} + for _, suite := range suites { + providedSuitePaths[suite.Path] = true + } + + d.packageHashes.StartTrackingUsage() + + for _, suite := range d.suites { + if providedSuitePaths[suite.Suite.Path] { + if suite.Delta() > 0 { + delta.modifiedSuites = append(delta.modifiedSuites, suite) + } + } else { + delta.RemovedSuites = append(delta.RemovedSuites, suite) + } + } + + d.packageHashes.StopTrackingUsageAndPrune() + + for _, suite := range suites { + _, ok := d.suites[suite.Path] + if !ok { + s, err := NewSuite(suite, d.maxDepth, d.packageHashes) + if err != nil { + errors[suite] = err + continue + } + d.suites[suite.Path] = s + delta.NewSuites = append(delta.NewSuites, s) + } + } + + return delta, errors +} + +func (d *DeltaTracker) WillRun(suite testsuite.TestSuite) error { + s, ok := d.suites[suite.Path] + if !ok { + return fmt.Errorf("unknown suite %s", suite.Path) + } + + return s.MarkAsRunAndRecomputedDependencies(d.maxDepth) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/dependencies.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch/dependencies.go new file mode 100644 index 00000000..f5ddff30 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/watch/dependencies.go @@ -0,0 +1,92 @@ +package watch + +import ( + "go/build" + "regexp" +) + +var ginkgoAndGomegaFilter = regexp.MustCompile(`github\.com/onsi/ginkgo|github\.com/onsi/gomega`) +var ginkgoIntegrationTestFilter = regexp.MustCompile(`github\.com/onsi/ginkgo/integration`) //allow us to integration test this thing + +type Dependencies struct { + deps map[string]int +} + +func NewDependencies(path string, maxDepth int) (Dependencies, error) { + d := Dependencies{ + deps: map[string]int{}, + } + + if maxDepth == 0 { + return d, nil + } + + err := d.seedWithDepsForPackageAtPath(path) + if err != nil { + return d, err + } + + for depth := 1; depth < maxDepth; depth++ { + n := len(d.deps) + d.addDepsForDepth(depth) + if n == len(d.deps) { + break + } + } + + return d, nil +} + +func (d Dependencies) Dependencies() map[string]int { + return d.deps +} + +func (d Dependencies) seedWithDepsForPackageAtPath(path string) error { + pkg, err := build.ImportDir(path, 0) + if err != nil { + return err + } + + d.resolveAndAdd(pkg.Imports, 1) + d.resolveAndAdd(pkg.TestImports, 1) + d.resolveAndAdd(pkg.XTestImports, 1) + + delete(d.deps, pkg.Dir) + return nil +} + +func (d Dependencies) addDepsForDepth(depth int) { + for dep, depDepth := range d.deps { + if depDepth == depth { + d.addDepsForDep(dep, depth+1) + } + } +} + +func (d Dependencies) addDepsForDep(dep string, depth int) { + pkg, err := build.ImportDir(dep, 0) + if err != nil { + println(err.Error()) + return + } + d.resolveAndAdd(pkg.Imports, depth) +} + +func (d Dependencies) resolveAndAdd(deps []string, depth int) { + for _, dep := range deps { + pkg, err := build.Import(dep, ".", 0) + if err != nil { + continue + } + if !pkg.Goroot && (!ginkgoAndGomegaFilter.Match([]byte(pkg.Dir)) || ginkgoIntegrationTestFilter.Match([]byte(pkg.Dir))) { + d.addDepIfNotPresent(pkg.Dir, depth) + } + } +} + +func (d Dependencies) addDepIfNotPresent(dep string, depth int) { + _, ok := d.deps[dep] + if !ok { + d.deps[dep] = depth + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hash.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hash.go new file mode 100644 index 00000000..67e2c1c3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hash.go @@ -0,0 +1,104 @@ +package watch + +import ( + "fmt" + "io/ioutil" + "os" + "regexp" + "time" +) + +var goTestRegExp = regexp.MustCompile(`_test\.go$`) + +type PackageHash struct { + CodeModifiedTime time.Time + TestModifiedTime time.Time + Deleted bool + + path string + codeHash string + testHash string + watchRegExp *regexp.Regexp +} + +func NewPackageHash(path string, watchRegExp *regexp.Regexp) *PackageHash { + p := &PackageHash{ + path: path, + watchRegExp: watchRegExp, + } + + p.codeHash, _, p.testHash, _, p.Deleted = p.computeHashes() + + return p +} + +func (p *PackageHash) CheckForChanges() bool { + codeHash, codeModifiedTime, testHash, testModifiedTime, deleted := p.computeHashes() + + if deleted { + if !p.Deleted { + t := time.Now() + p.CodeModifiedTime = t + p.TestModifiedTime = t + } + p.Deleted = true + return true + } + + modified := false + p.Deleted = false + + if p.codeHash != codeHash { + p.CodeModifiedTime = codeModifiedTime + modified = true + } + if p.testHash != testHash { + p.TestModifiedTime = testModifiedTime + modified = true + } + + p.codeHash = codeHash + p.testHash = testHash + return modified +} + +func (p *PackageHash) computeHashes() (codeHash string, codeModifiedTime time.Time, testHash string, testModifiedTime time.Time, deleted bool) { + infos, err := ioutil.ReadDir(p.path) + + if err != nil { + deleted = true + return + } + + for _, info := range infos { + if info.IsDir() { + continue + } + + if goTestRegExp.Match([]byte(info.Name())) { + testHash += p.hashForFileInfo(info) + if info.ModTime().After(testModifiedTime) { + testModifiedTime = info.ModTime() + } + continue + } + + if p.watchRegExp.Match([]byte(info.Name())) { + codeHash += p.hashForFileInfo(info) + if info.ModTime().After(codeModifiedTime) { + codeModifiedTime = info.ModTime() + } + } + } + + testHash += codeHash + if codeModifiedTime.After(testModifiedTime) { + testModifiedTime = codeModifiedTime + } + + return +} + +func (p *PackageHash) hashForFileInfo(info os.FileInfo) string { + return fmt.Sprintf("%s_%d_%d", info.Name(), info.Size(), info.ModTime().UnixNano()) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hashes.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hashes.go new file mode 100644 index 00000000..b4892beb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hashes.go @@ -0,0 +1,85 @@ +package watch + +import ( + "path/filepath" + "regexp" + "sync" +) + +type PackageHashes struct { + PackageHashes map[string]*PackageHash + usedPaths map[string]bool + watchRegExp *regexp.Regexp + lock *sync.Mutex +} + +func NewPackageHashes(watchRegExp *regexp.Regexp) *PackageHashes { + return &PackageHashes{ + PackageHashes: map[string]*PackageHash{}, + usedPaths: nil, + watchRegExp: watchRegExp, + lock: &sync.Mutex{}, + } +} + +func (p *PackageHashes) CheckForChanges() []string { + p.lock.Lock() + defer p.lock.Unlock() + + modified := []string{} + + for _, packageHash := range p.PackageHashes { + if packageHash.CheckForChanges() { + modified = append(modified, packageHash.path) + } + } + + return modified +} + +func (p *PackageHashes) Add(path string) *PackageHash { + p.lock.Lock() + defer p.lock.Unlock() + + path, _ = filepath.Abs(path) + _, ok := p.PackageHashes[path] + if !ok { + p.PackageHashes[path] = NewPackageHash(path, p.watchRegExp) + } + + if p.usedPaths != nil { + p.usedPaths[path] = true + } + return p.PackageHashes[path] +} + +func (p *PackageHashes) Get(path string) *PackageHash { + p.lock.Lock() + defer p.lock.Unlock() + + path, _ = filepath.Abs(path) + if p.usedPaths != nil { + p.usedPaths[path] = true + } + return p.PackageHashes[path] +} + +func (p *PackageHashes) StartTrackingUsage() { + p.lock.Lock() + defer p.lock.Unlock() + + p.usedPaths = map[string]bool{} +} + +func (p *PackageHashes) StopTrackingUsageAndPrune() { + p.lock.Lock() + defer p.lock.Unlock() + + for path := range p.PackageHashes { + if !p.usedPaths[path] { + delete(p.PackageHashes, path) + } + } + + p.usedPaths = nil +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/suite.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch/suite.go new file mode 100644 index 00000000..5deaba7c --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/watch/suite.go @@ -0,0 +1,87 @@ +package watch + +import ( + "fmt" + "math" + "time" + + "github.com/onsi/ginkgo/ginkgo/testsuite" +) + +type Suite struct { + Suite testsuite.TestSuite + RunTime time.Time + Dependencies Dependencies + + sharedPackageHashes *PackageHashes +} + +func NewSuite(suite testsuite.TestSuite, maxDepth int, sharedPackageHashes *PackageHashes) (*Suite, error) { + deps, err := NewDependencies(suite.Path, maxDepth) + if err != nil { + return nil, err + } + + sharedPackageHashes.Add(suite.Path) + for dep := range deps.Dependencies() { + sharedPackageHashes.Add(dep) + } + + return &Suite{ + Suite: suite, + Dependencies: deps, + + sharedPackageHashes: sharedPackageHashes, + }, nil +} + +func (s *Suite) Delta() float64 { + delta := s.delta(s.Suite.Path, true, 0) * 1000 + for dep, depth := range s.Dependencies.Dependencies() { + delta += s.delta(dep, false, depth) + } + return delta +} + +func (s *Suite) MarkAsRunAndRecomputedDependencies(maxDepth int) error { + s.RunTime = time.Now() + + deps, err := NewDependencies(s.Suite.Path, maxDepth) + if err != nil { + return err + } + + s.sharedPackageHashes.Add(s.Suite.Path) + for dep := range deps.Dependencies() { + s.sharedPackageHashes.Add(dep) + } + + s.Dependencies = deps + + return nil +} + +func (s *Suite) Description() string { + numDeps := len(s.Dependencies.Dependencies()) + pluralizer := "ies" + if numDeps == 1 { + pluralizer = "y" + } + return fmt.Sprintf("%s [%d dependenc%s]", s.Suite.Path, numDeps, pluralizer) +} + +func (s *Suite) delta(packagePath string, includeTests bool, depth int) float64 { + return math.Max(float64(s.dt(packagePath, includeTests)), 0) / float64(depth+1) +} + +func (s *Suite) dt(packagePath string, includeTests bool) time.Duration { + packageHash := s.sharedPackageHashes.Get(packagePath) + var modifiedTime time.Time + if includeTests { + modifiedTime = packageHash.TestModifiedTime + } else { + modifiedTime = packageHash.CodeModifiedTime + } + + return modifiedTime.Sub(s.RunTime) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go new file mode 100644 index 00000000..a6ef053c --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go @@ -0,0 +1,175 @@ +package main + +import ( + "flag" + "fmt" + "regexp" + "time" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/ginkgo/interrupthandler" + "github.com/onsi/ginkgo/ginkgo/testrunner" + "github.com/onsi/ginkgo/ginkgo/testsuite" + "github.com/onsi/ginkgo/ginkgo/watch" + colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" +) + +func BuildWatchCommand() *Command { + commandFlags := NewWatchCommandFlags(flag.NewFlagSet("watch", flag.ExitOnError)) + interruptHandler := interrupthandler.NewInterruptHandler() + notifier := NewNotifier(commandFlags) + watcher := &SpecWatcher{ + commandFlags: commandFlags, + notifier: notifier, + interruptHandler: interruptHandler, + suiteRunner: NewSuiteRunner(notifier, interruptHandler), + } + + return &Command{ + Name: "watch", + FlagSet: commandFlags.FlagSet, + UsageCommand: "ginkgo watch -- ", + Usage: []string{ + "Watches the tests in the passed in and runs them when changes occur.", + "Any arguments after -- will be passed to the test.", + }, + Command: watcher.WatchSpecs, + SuppressFlagDocumentation: true, + FlagDocSubstitute: []string{ + "Accepts all the flags that the ginkgo command accepts except for --keepGoing and --untilItFails", + }, + } +} + +type SpecWatcher struct { + commandFlags *RunWatchAndBuildCommandFlags + notifier *Notifier + interruptHandler *interrupthandler.InterruptHandler + suiteRunner *SuiteRunner +} + +func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { + w.commandFlags.computeNodes() + w.notifier.VerifyNotificationsAreAvailable() + + w.WatchSuites(args, additionalArgs) +} + +func (w *SpecWatcher) runnersForSuites(suites []testsuite.TestSuite, additionalArgs []string) []*testrunner.TestRunner { + runners := []*testrunner.TestRunner{} + + for _, suite := range suites { + runners = append(runners, testrunner.New(suite, w.commandFlags.NumCPU, w.commandFlags.ParallelStream, w.commandFlags.Timeout, w.commandFlags.GoOpts, additionalArgs)) + } + + return runners +} + +func (w *SpecWatcher) WatchSuites(args []string, additionalArgs []string) { + suites, _ := findSuites(args, w.commandFlags.Recurse, w.commandFlags.SkipPackage, false) + + if len(suites) == 0 { + complainAndQuit("Found no test suites") + } + + fmt.Printf("Identified %d test %s. Locating dependencies to a depth of %d (this may take a while)...\n", len(suites), pluralizedWord("suite", "suites", len(suites)), w.commandFlags.Depth) + deltaTracker := watch.NewDeltaTracker(w.commandFlags.Depth, regexp.MustCompile(w.commandFlags.WatchRegExp)) + delta, errors := deltaTracker.Delta(suites) + + fmt.Printf("Watching %d %s:\n", len(delta.NewSuites), pluralizedWord("suite", "suites", len(delta.NewSuites))) + for _, suite := range delta.NewSuites { + fmt.Println(" " + suite.Description()) + } + + for suite, err := range errors { + fmt.Printf("Failed to watch %s: %s\n", suite.PackageName, err) + } + + if len(suites) == 1 { + runners := w.runnersForSuites(suites, additionalArgs) + w.suiteRunner.RunSuites(runners, w.commandFlags.NumCompilers, true, nil) + runners[0].CleanUp() + } + + ticker := time.NewTicker(time.Second) + + for { + select { + case <-ticker.C: + suites, _ := findSuites(args, w.commandFlags.Recurse, w.commandFlags.SkipPackage, false) + delta, _ := deltaTracker.Delta(suites) + coloredStream := colorable.NewColorableStdout() + + suitesToRun := []testsuite.TestSuite{} + + if len(delta.NewSuites) > 0 { + fmt.Fprintf(coloredStream, greenColor+"Detected %d new %s:\n"+defaultStyle, len(delta.NewSuites), pluralizedWord("suite", "suites", len(delta.NewSuites))) + for _, suite := range delta.NewSuites { + suitesToRun = append(suitesToRun, suite.Suite) + fmt.Fprintln(coloredStream, " "+suite.Description()) + } + } + + modifiedSuites := delta.ModifiedSuites() + if len(modifiedSuites) > 0 { + fmt.Fprintln(coloredStream, greenColor+"\nDetected changes in:"+defaultStyle) + for _, pkg := range delta.ModifiedPackages { + fmt.Fprintln(coloredStream, " "+pkg) + } + fmt.Fprintf(coloredStream, greenColor+"Will run %d %s:\n"+defaultStyle, len(modifiedSuites), pluralizedWord("suite", "suites", len(modifiedSuites))) + for _, suite := range modifiedSuites { + suitesToRun = append(suitesToRun, suite.Suite) + fmt.Fprintln(coloredStream, " "+suite.Description()) + } + fmt.Fprintln(coloredStream, "") + } + + if len(suitesToRun) > 0 { + w.UpdateSeed() + w.ComputeSuccinctMode(len(suitesToRun)) + runners := w.runnersForSuites(suitesToRun, additionalArgs) + result, _ := w.suiteRunner.RunSuites(runners, w.commandFlags.NumCompilers, true, func(suite testsuite.TestSuite) { + deltaTracker.WillRun(suite) + }) + for _, runner := range runners { + runner.CleanUp() + } + if !w.interruptHandler.WasInterrupted() { + color := redColor + if result.Passed { + color = greenColor + } + fmt.Fprintln(coloredStream, color+"\nDone. Resuming watch..."+defaultStyle) + } + } + + case <-w.interruptHandler.C: + return + } + } +} + +func (w *SpecWatcher) ComputeSuccinctMode(numSuites int) { + if config.DefaultReporterConfig.Verbose { + config.DefaultReporterConfig.Succinct = false + return + } + + if w.commandFlags.wasSet("succinct") { + return + } + + if numSuites == 1 { + config.DefaultReporterConfig.Succinct = false + } + + if numSuites > 1 { + config.DefaultReporterConfig.Succinct = true + } +} + +func (w *SpecWatcher) UpdateSeed() { + if !w.commandFlags.wasSet("seed") { + config.GinkgoConfig.RandomSeed = time.Now().Unix() + } +} diff --git a/vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go b/vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go new file mode 100644 index 00000000..aa89d6cb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go @@ -0,0 +1,48 @@ +package codelocation + +import ( + "regexp" + "runtime" + "runtime/debug" + "strings" + + "github.com/onsi/ginkgo/types" +) + +func New(skip int) types.CodeLocation { + _, file, line, _ := runtime.Caller(skip + 1) + stackTrace := PruneStack(string(debug.Stack()), skip+1) + return types.CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace} +} + +// PruneStack removes references to functions that are internal to Ginkgo +// and the Go runtime from a stack string and a certain number of stack entries +// at the beginning of the stack. The stack string has the format +// as returned by runtime/debug.Stack. The leading goroutine information is +// optional and always removed if present. Beware that runtime/debug.Stack +// adds itself as first entry, so typically skip must be >= 1 to remove that +// entry. +func PruneStack(fullStackTrace string, skip int) string { + stack := strings.Split(fullStackTrace, "\n") + // Ensure that the even entries are the method names and the + // the odd entries the source code information. + if len(stack) > 0 && strings.HasPrefix(stack[0], "goroutine ") { + // Ignore "goroutine 29 [running]:" line. + stack = stack[1:] + } + // The "+1" is for skipping over the initial entry, which is + // runtime/debug.Stack() itself. + if len(stack) > 2*(skip+1) { + stack = stack[2*(skip+1):] + } + prunedStack := []string{} + re := regexp.MustCompile(`\/ginkgo\/|\/pkg\/testing\/|\/pkg\/runtime\/`) + for i := 0; i < len(stack)/2; i++ { + // We filter out based on the source code file name. + if !re.Match([]byte(stack[i*2+1])) { + prunedStack = append(prunedStack, stack[i*2]) + prunedStack = append(prunedStack, stack[i*2+1]) + } + } + return strings.Join(prunedStack, "\n") +} diff --git a/vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go b/vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go new file mode 100644 index 00000000..0737746d --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go @@ -0,0 +1,151 @@ +package containernode + +import ( + "math/rand" + "sort" + + "github.com/onsi/ginkgo/internal/leafnodes" + "github.com/onsi/ginkgo/types" +) + +type subjectOrContainerNode struct { + containerNode *ContainerNode + subjectNode leafnodes.SubjectNode +} + +func (n subjectOrContainerNode) text() string { + if n.containerNode != nil { + return n.containerNode.Text() + } else { + return n.subjectNode.Text() + } +} + +type CollatedNodes struct { + Containers []*ContainerNode + Subject leafnodes.SubjectNode +} + +type ContainerNode struct { + text string + flag types.FlagType + codeLocation types.CodeLocation + + setupNodes []leafnodes.BasicNode + subjectAndContainerNodes []subjectOrContainerNode +} + +func New(text string, flag types.FlagType, codeLocation types.CodeLocation) *ContainerNode { + return &ContainerNode{ + text: text, + flag: flag, + codeLocation: codeLocation, + } +} + +func (container *ContainerNode) Shuffle(r *rand.Rand) { + sort.Sort(container) + permutation := r.Perm(len(container.subjectAndContainerNodes)) + shuffledNodes := make([]subjectOrContainerNode, len(container.subjectAndContainerNodes)) + for i, j := range permutation { + shuffledNodes[i] = container.subjectAndContainerNodes[j] + } + container.subjectAndContainerNodes = shuffledNodes +} + +func (node *ContainerNode) BackPropagateProgrammaticFocus() bool { + if node.flag == types.FlagTypePending { + return false + } + + shouldUnfocus := false + for _, subjectOrContainerNode := range node.subjectAndContainerNodes { + if subjectOrContainerNode.containerNode != nil { + shouldUnfocus = subjectOrContainerNode.containerNode.BackPropagateProgrammaticFocus() || shouldUnfocus + } else { + shouldUnfocus = (subjectOrContainerNode.subjectNode.Flag() == types.FlagTypeFocused) || shouldUnfocus + } + } + + if shouldUnfocus { + if node.flag == types.FlagTypeFocused { + node.flag = types.FlagTypeNone + } + return true + } + + return node.flag == types.FlagTypeFocused +} + +func (node *ContainerNode) Collate() []CollatedNodes { + return node.collate([]*ContainerNode{}) +} + +func (node *ContainerNode) collate(enclosingContainers []*ContainerNode) []CollatedNodes { + collated := make([]CollatedNodes, 0) + + containers := make([]*ContainerNode, len(enclosingContainers)) + copy(containers, enclosingContainers) + containers = append(containers, node) + + for _, subjectOrContainer := range node.subjectAndContainerNodes { + if subjectOrContainer.containerNode != nil { + collated = append(collated, subjectOrContainer.containerNode.collate(containers)...) + } else { + collated = append(collated, CollatedNodes{ + Containers: containers, + Subject: subjectOrContainer.subjectNode, + }) + } + } + + return collated +} + +func (node *ContainerNode) PushContainerNode(container *ContainerNode) { + node.subjectAndContainerNodes = append(node.subjectAndContainerNodes, subjectOrContainerNode{containerNode: container}) +} + +func (node *ContainerNode) PushSubjectNode(subject leafnodes.SubjectNode) { + node.subjectAndContainerNodes = append(node.subjectAndContainerNodes, subjectOrContainerNode{subjectNode: subject}) +} + +func (node *ContainerNode) PushSetupNode(setupNode leafnodes.BasicNode) { + node.setupNodes = append(node.setupNodes, setupNode) +} + +func (node *ContainerNode) SetupNodesOfType(nodeType types.SpecComponentType) []leafnodes.BasicNode { + nodes := []leafnodes.BasicNode{} + for _, setupNode := range node.setupNodes { + if setupNode.Type() == nodeType { + nodes = append(nodes, setupNode) + } + } + return nodes +} + +func (node *ContainerNode) Text() string { + return node.text +} + +func (node *ContainerNode) CodeLocation() types.CodeLocation { + return node.codeLocation +} + +func (node *ContainerNode) Flag() types.FlagType { + return node.flag +} + +//sort.Interface + +func (node *ContainerNode) Len() int { + return len(node.subjectAndContainerNodes) +} + +func (node *ContainerNode) Less(i, j int) bool { + return node.subjectAndContainerNodes[i].text() < node.subjectAndContainerNodes[j].text() +} + +func (node *ContainerNode) Swap(i, j int) { + node.subjectAndContainerNodes[i], node.subjectAndContainerNodes[j] = node.subjectAndContainerNodes[j], node.subjectAndContainerNodes[i] +} diff --git a/vendor/github.com/onsi/ginkgo/internal/failer/failer.go b/vendor/github.com/onsi/ginkgo/internal/failer/failer.go new file mode 100644 index 00000000..678ea251 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/failer/failer.go @@ -0,0 +1,92 @@ +package failer + +import ( + "fmt" + "sync" + + "github.com/onsi/ginkgo/types" +) + +type Failer struct { + lock *sync.Mutex + failure types.SpecFailure + state types.SpecState +} + +func New() *Failer { + return &Failer{ + lock: &sync.Mutex{}, + state: types.SpecStatePassed, + } +} + +func (f *Failer) Panic(location types.CodeLocation, forwardedPanic interface{}) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.state == types.SpecStatePassed { + f.state = types.SpecStatePanicked + f.failure = types.SpecFailure{ + Message: "Test Panicked", + Location: location, + ForwardedPanic: fmt.Sprintf("%v", forwardedPanic), + } + } +} + +func (f *Failer) Timeout(location types.CodeLocation) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.state == types.SpecStatePassed { + f.state = types.SpecStateTimedOut + f.failure = types.SpecFailure{ + Message: "Timed out", + Location: location, + } + } +} + +func (f *Failer) Fail(message string, location types.CodeLocation) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.state == types.SpecStatePassed { + f.state = types.SpecStateFailed + f.failure = types.SpecFailure{ + Message: message, + Location: location, + } + } +} + +func (f *Failer) Drain(componentType types.SpecComponentType, componentIndex int, componentCodeLocation types.CodeLocation) (types.SpecFailure, types.SpecState) { + f.lock.Lock() + defer f.lock.Unlock() + + failure := f.failure + outcome := f.state + if outcome != types.SpecStatePassed { + failure.ComponentType = componentType + failure.ComponentIndex = componentIndex + failure.ComponentCodeLocation = componentCodeLocation + } + + f.state = types.SpecStatePassed + f.failure = types.SpecFailure{} + + return failure, outcome +} + +func (f *Failer) Skip(message string, location types.CodeLocation) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.state == types.SpecStatePassed { + f.state = types.SpecStateSkipped + f.failure = types.SpecFailure{ + Message: message, + Location: location, + } + } +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go new file mode 100644 index 00000000..393901e1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go @@ -0,0 +1,103 @@ +package leafnodes + +import ( + "math" + "time" + + "sync" + + "github.com/onsi/ginkgo/types" +) + +type benchmarker struct { + mu sync.Mutex + measurements map[string]*types.SpecMeasurement + orderCounter int +} + +func newBenchmarker() *benchmarker { + return &benchmarker{ + measurements: make(map[string]*types.SpecMeasurement), + } +} + +func (b *benchmarker) Time(name string, body func(), info ...interface{}) (elapsedTime time.Duration) { + t := time.Now() + body() + elapsedTime = time.Since(t) + + b.mu.Lock() + defer b.mu.Unlock() + measurement := b.getMeasurement(name, "Fastest Time", "Slowest Time", "Average Time", "s", 3, info...) + measurement.Results = append(measurement.Results, elapsedTime.Seconds()) + + return +} + +func (b *benchmarker) RecordValue(name string, value float64, info ...interface{}) { + b.mu.Lock() + measurement := b.getMeasurement(name, "Smallest", " Largest", " Average", "", 3, info...) + defer b.mu.Unlock() + measurement.Results = append(measurement.Results, value) +} + +func (b *benchmarker) RecordValueWithPrecision(name string, value float64, units string, precision int, info ...interface{}) { + b.mu.Lock() + measurement := b.getMeasurement(name, "Smallest", " Largest", " Average", units, precision, info...) + defer b.mu.Unlock() + measurement.Results = append(measurement.Results, value) +} + +func (b *benchmarker) getMeasurement(name string, smallestLabel string, largestLabel string, averageLabel string, units string, precision int, info ...interface{}) *types.SpecMeasurement { + measurement, ok := b.measurements[name] + if !ok { + var computedInfo interface{} + computedInfo = nil + if len(info) > 0 { + computedInfo = info[0] + } + measurement = &types.SpecMeasurement{ + Name: name, + Info: computedInfo, + Order: b.orderCounter, + SmallestLabel: smallestLabel, + LargestLabel: largestLabel, + AverageLabel: averageLabel, + Units: units, + Precision: precision, + Results: make([]float64, 0), + } + b.measurements[name] = measurement + b.orderCounter++ + } + + return measurement +} + +func (b *benchmarker) measurementsReport() map[string]*types.SpecMeasurement { + b.mu.Lock() + defer b.mu.Unlock() + for _, measurement := range b.measurements { + measurement.Smallest = math.MaxFloat64 + measurement.Largest = -math.MaxFloat64 + sum := float64(0) + sumOfSquares := float64(0) + + for _, result := range measurement.Results { + if result > measurement.Largest { + measurement.Largest = result + } + if result < measurement.Smallest { + measurement.Smallest = result + } + sum += result + sumOfSquares += result * result + } + + n := float64(len(measurement.Results)) + measurement.Average = sum / n + measurement.StdDeviation = math.Sqrt(sumOfSquares/n - (sum/n)*(sum/n)) + } + + return b.measurements +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go new file mode 100644 index 00000000..8c3902d6 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go @@ -0,0 +1,19 @@ +package leafnodes + +import ( + "github.com/onsi/ginkgo/types" +) + +type BasicNode interface { + Type() types.SpecComponentType + Run() (types.SpecState, types.SpecFailure) + CodeLocation() types.CodeLocation +} + +type SubjectNode interface { + BasicNode + + Text() string + Flag() types.FlagType + Samples() int +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go new file mode 100644 index 00000000..6eded7b7 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go @@ -0,0 +1,47 @@ +package leafnodes + +import ( + "time" + + "github.com/onsi/ginkgo/internal/failer" + "github.com/onsi/ginkgo/types" +) + +type ItNode struct { + runner *runner + + flag types.FlagType + text string +} + +func NewItNode(text string, body interface{}, flag types.FlagType, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *ItNode { + return &ItNode{ + runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeIt, componentIndex), + flag: flag, + text: text, + } +} + +func (node *ItNode) Run() (outcome types.SpecState, failure types.SpecFailure) { + return node.runner.run() +} + +func (node *ItNode) Type() types.SpecComponentType { + return types.SpecComponentTypeIt +} + +func (node *ItNode) Text() string { + return node.text +} + +func (node *ItNode) Flag() types.FlagType { + return node.flag +} + +func (node *ItNode) CodeLocation() types.CodeLocation { + return node.runner.codeLocation +} + +func (node *ItNode) Samples() int { + return 1 +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go new file mode 100644 index 00000000..3ab9a6d5 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go @@ -0,0 +1,62 @@ +package leafnodes + +import ( + "reflect" + + "github.com/onsi/ginkgo/internal/failer" + "github.com/onsi/ginkgo/types" +) + +type MeasureNode struct { + runner *runner + + text string + flag types.FlagType + samples int + benchmarker *benchmarker +} + +func NewMeasureNode(text string, body interface{}, flag types.FlagType, codeLocation types.CodeLocation, samples int, failer *failer.Failer, componentIndex int) *MeasureNode { + benchmarker := newBenchmarker() + + wrappedBody := func() { + reflect.ValueOf(body).Call([]reflect.Value{reflect.ValueOf(benchmarker)}) + } + + return &MeasureNode{ + runner: newRunner(wrappedBody, codeLocation, 0, failer, types.SpecComponentTypeMeasure, componentIndex), + + text: text, + flag: flag, + samples: samples, + benchmarker: benchmarker, + } +} + +func (node *MeasureNode) Run() (outcome types.SpecState, failure types.SpecFailure) { + return node.runner.run() +} + +func (node *MeasureNode) MeasurementsReport() map[string]*types.SpecMeasurement { + return node.benchmarker.measurementsReport() +} + +func (node *MeasureNode) Type() types.SpecComponentType { + return types.SpecComponentTypeMeasure +} + +func (node *MeasureNode) Text() string { + return node.text +} + +func (node *MeasureNode) Flag() types.FlagType { + return node.flag +} + +func (node *MeasureNode) CodeLocation() types.CodeLocation { + return node.runner.codeLocation +} + +func (node *MeasureNode) Samples() int { + return node.samples +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go new file mode 100644 index 00000000..16cb66c3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go @@ -0,0 +1,117 @@ +package leafnodes + +import ( + "fmt" + "reflect" + "time" + + "github.com/onsi/ginkgo/internal/codelocation" + "github.com/onsi/ginkgo/internal/failer" + "github.com/onsi/ginkgo/types" +) + +type runner struct { + isAsync bool + asyncFunc func(chan<- interface{}) + syncFunc func() + codeLocation types.CodeLocation + timeoutThreshold time.Duration + nodeType types.SpecComponentType + componentIndex int + failer *failer.Failer +} + +func newRunner(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, nodeType types.SpecComponentType, componentIndex int) *runner { + bodyType := reflect.TypeOf(body) + if bodyType.Kind() != reflect.Func { + panic(fmt.Sprintf("Expected a function but got something else at %v", codeLocation)) + } + + runner := &runner{ + codeLocation: codeLocation, + timeoutThreshold: timeout, + failer: failer, + nodeType: nodeType, + componentIndex: componentIndex, + } + + switch bodyType.NumIn() { + case 0: + runner.syncFunc = body.(func()) + return runner + case 1: + if !(bodyType.In(0).Kind() == reflect.Chan && bodyType.In(0).Elem().Kind() == reflect.Interface) { + panic(fmt.Sprintf("Must pass a Done channel to function at %v", codeLocation)) + } + + wrappedBody := func(done chan<- interface{}) { + bodyValue := reflect.ValueOf(body) + bodyValue.Call([]reflect.Value{reflect.ValueOf(done)}) + } + + runner.isAsync = true + runner.asyncFunc = wrappedBody + return runner + } + + panic(fmt.Sprintf("Too many arguments to function at %v", codeLocation)) +} + +func (r *runner) run() (outcome types.SpecState, failure types.SpecFailure) { + if r.isAsync { + return r.runAsync() + } else { + return r.runSync() + } +} + +func (r *runner) runAsync() (outcome types.SpecState, failure types.SpecFailure) { + done := make(chan interface{}, 1) + + go func() { + finished := false + + defer func() { + if e := recover(); e != nil || !finished { + r.failer.Panic(codelocation.New(2), e) + select { + case <-done: + break + default: + close(done) + } + } + }() + + r.asyncFunc(done) + finished = true + }() + + // If this goroutine gets no CPU time before the select block, + // the <-done case may complete even if the test took longer than the timeoutThreshold. + // This can cause flaky behaviour, but we haven't seen it in the wild. + select { + case <-done: + case <-time.After(r.timeoutThreshold): + r.failer.Timeout(r.codeLocation) + } + + failure, outcome = r.failer.Drain(r.nodeType, r.componentIndex, r.codeLocation) + return +} +func (r *runner) runSync() (outcome types.SpecState, failure types.SpecFailure) { + finished := false + + defer func() { + if e := recover(); e != nil || !finished { + r.failer.Panic(codelocation.New(2), e) + } + + failure, outcome = r.failer.Drain(r.nodeType, r.componentIndex, r.codeLocation) + }() + + r.syncFunc() + finished = true + + return +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go new file mode 100644 index 00000000..e3e9cb7c --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go @@ -0,0 +1,48 @@ +package leafnodes + +import ( + "time" + + "github.com/onsi/ginkgo/internal/failer" + "github.com/onsi/ginkgo/types" +) + +type SetupNode struct { + runner *runner +} + +func (node *SetupNode) Run() (outcome types.SpecState, failure types.SpecFailure) { + return node.runner.run() +} + +func (node *SetupNode) Type() types.SpecComponentType { + return node.runner.nodeType +} + +func (node *SetupNode) CodeLocation() types.CodeLocation { + return node.runner.codeLocation +} + +func NewBeforeEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { + return &SetupNode{ + runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeBeforeEach, componentIndex), + } +} + +func NewAfterEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { + return &SetupNode{ + runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeAfterEach, componentIndex), + } +} + +func NewJustBeforeEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { + return &SetupNode{ + runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeJustBeforeEach, componentIndex), + } +} + +func NewJustAfterEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { + return &SetupNode{ + runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeJustAfterEach, componentIndex), + } +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go new file mode 100644 index 00000000..80f16ed7 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go @@ -0,0 +1,55 @@ +package leafnodes + +import ( + "time" + + "github.com/onsi/ginkgo/internal/failer" + "github.com/onsi/ginkgo/types" +) + +type SuiteNode interface { + Run(parallelNode int, parallelTotal int, syncHost string) bool + Passed() bool + Summary() *types.SetupSummary +} + +type simpleSuiteNode struct { + runner *runner + outcome types.SpecState + failure types.SpecFailure + runTime time.Duration +} + +func (node *simpleSuiteNode) Run(parallelNode int, parallelTotal int, syncHost string) bool { + t := time.Now() + node.outcome, node.failure = node.runner.run() + node.runTime = time.Since(t) + + return node.outcome == types.SpecStatePassed +} + +func (node *simpleSuiteNode) Passed() bool { + return node.outcome == types.SpecStatePassed +} + +func (node *simpleSuiteNode) Summary() *types.SetupSummary { + return &types.SetupSummary{ + ComponentType: node.runner.nodeType, + CodeLocation: node.runner.codeLocation, + State: node.outcome, + RunTime: node.runTime, + Failure: node.failure, + } +} + +func NewBeforeSuiteNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { + return &simpleSuiteNode{ + runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeBeforeSuite, 0), + } +} + +func NewAfterSuiteNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { + return &simpleSuiteNode{ + runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeAfterSuite, 0), + } +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go new file mode 100644 index 00000000..a721d0cf --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go @@ -0,0 +1,90 @@ +package leafnodes + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "time" + + "github.com/onsi/ginkgo/internal/failer" + "github.com/onsi/ginkgo/types" +) + +type synchronizedAfterSuiteNode struct { + runnerA *runner + runnerB *runner + + outcome types.SpecState + failure types.SpecFailure + runTime time.Duration +} + +func NewSynchronizedAfterSuiteNode(bodyA interface{}, bodyB interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { + return &synchronizedAfterSuiteNode{ + runnerA: newRunner(bodyA, codeLocation, timeout, failer, types.SpecComponentTypeAfterSuite, 0), + runnerB: newRunner(bodyB, codeLocation, timeout, failer, types.SpecComponentTypeAfterSuite, 0), + } +} + +func (node *synchronizedAfterSuiteNode) Run(parallelNode int, parallelTotal int, syncHost string) bool { + node.outcome, node.failure = node.runnerA.run() + + if parallelNode == 1 { + if parallelTotal > 1 { + node.waitUntilOtherNodesAreDone(syncHost) + } + + outcome, failure := node.runnerB.run() + + if node.outcome == types.SpecStatePassed { + node.outcome, node.failure = outcome, failure + } + } + + return node.outcome == types.SpecStatePassed +} + +func (node *synchronizedAfterSuiteNode) Passed() bool { + return node.outcome == types.SpecStatePassed +} + +func (node *synchronizedAfterSuiteNode) Summary() *types.SetupSummary { + return &types.SetupSummary{ + ComponentType: node.runnerA.nodeType, + CodeLocation: node.runnerA.codeLocation, + State: node.outcome, + RunTime: node.runTime, + Failure: node.failure, + } +} + +func (node *synchronizedAfterSuiteNode) waitUntilOtherNodesAreDone(syncHost string) { + for { + if node.canRun(syncHost) { + return + } + + time.Sleep(50 * time.Millisecond) + } +} + +func (node *synchronizedAfterSuiteNode) canRun(syncHost string) bool { + resp, err := http.Get(syncHost + "/RemoteAfterSuiteData") + if err != nil || resp.StatusCode != http.StatusOK { + return false + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return false + } + resp.Body.Close() + + afterSuiteData := types.RemoteAfterSuiteData{} + err = json.Unmarshal(body, &afterSuiteData) + if err != nil { + return false + } + + return afterSuiteData.CanRun +} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go new file mode 100644 index 00000000..d5c88931 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go @@ -0,0 +1,181 @@ +package leafnodes + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "reflect" + "time" + + "github.com/onsi/ginkgo/internal/failer" + "github.com/onsi/ginkgo/types" +) + +type synchronizedBeforeSuiteNode struct { + runnerA *runner + runnerB *runner + + data []byte + + outcome types.SpecState + failure types.SpecFailure + runTime time.Duration +} + +func NewSynchronizedBeforeSuiteNode(bodyA interface{}, bodyB interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { + node := &synchronizedBeforeSuiteNode{} + + node.runnerA = newRunner(node.wrapA(bodyA), codeLocation, timeout, failer, types.SpecComponentTypeBeforeSuite, 0) + node.runnerB = newRunner(node.wrapB(bodyB), codeLocation, timeout, failer, types.SpecComponentTypeBeforeSuite, 0) + + return node +} + +func (node *synchronizedBeforeSuiteNode) Run(parallelNode int, parallelTotal int, syncHost string) bool { + t := time.Now() + defer func() { + node.runTime = time.Since(t) + }() + + if parallelNode == 1 { + node.outcome, node.failure = node.runA(parallelTotal, syncHost) + } else { + node.outcome, node.failure = node.waitForA(syncHost) + } + + if node.outcome != types.SpecStatePassed { + return false + } + node.outcome, node.failure = node.runnerB.run() + + return node.outcome == types.SpecStatePassed +} + +func (node *synchronizedBeforeSuiteNode) runA(parallelTotal int, syncHost string) (types.SpecState, types.SpecFailure) { + outcome, failure := node.runnerA.run() + + if parallelTotal > 1 { + state := types.RemoteBeforeSuiteStatePassed + if outcome != types.SpecStatePassed { + state = types.RemoteBeforeSuiteStateFailed + } + json := (types.RemoteBeforeSuiteData{ + Data: node.data, + State: state, + }).ToJSON() + http.Post(syncHost+"/BeforeSuiteState", "application/json", bytes.NewBuffer(json)) + } + + return outcome, failure +} + +func (node *synchronizedBeforeSuiteNode) waitForA(syncHost string) (types.SpecState, types.SpecFailure) { + failure := func(message string) types.SpecFailure { + return types.SpecFailure{ + Message: message, + Location: node.runnerA.codeLocation, + ComponentType: node.runnerA.nodeType, + ComponentIndex: node.runnerA.componentIndex, + ComponentCodeLocation: node.runnerA.codeLocation, + } + } + for { + resp, err := http.Get(syncHost + "/BeforeSuiteState") + if err != nil || resp.StatusCode != http.StatusOK { + return types.SpecStateFailed, failure("Failed to fetch BeforeSuite state") + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return types.SpecStateFailed, failure("Failed to read BeforeSuite state") + } + resp.Body.Close() + + beforeSuiteData := types.RemoteBeforeSuiteData{} + err = json.Unmarshal(body, &beforeSuiteData) + if err != nil { + return types.SpecStateFailed, failure("Failed to decode BeforeSuite state") + } + + switch beforeSuiteData.State { + case types.RemoteBeforeSuiteStatePassed: + node.data = beforeSuiteData.Data + return types.SpecStatePassed, types.SpecFailure{} + case types.RemoteBeforeSuiteStateFailed: + return types.SpecStateFailed, failure("BeforeSuite on Node 1 failed") + case types.RemoteBeforeSuiteStateDisappeared: + return types.SpecStateFailed, failure("Node 1 disappeared before completing BeforeSuite") + } + + time.Sleep(50 * time.Millisecond) + } +} + +func (node *synchronizedBeforeSuiteNode) Passed() bool { + return node.outcome == types.SpecStatePassed +} + +func (node *synchronizedBeforeSuiteNode) Summary() *types.SetupSummary { + return &types.SetupSummary{ + ComponentType: node.runnerA.nodeType, + CodeLocation: node.runnerA.codeLocation, + State: node.outcome, + RunTime: node.runTime, + Failure: node.failure, + } +} + +func (node *synchronizedBeforeSuiteNode) wrapA(bodyA interface{}) interface{} { + typeA := reflect.TypeOf(bodyA) + if typeA.Kind() != reflect.Func { + panic("SynchronizedBeforeSuite expects a function as its first argument") + } + + takesNothing := typeA.NumIn() == 0 + takesADoneChannel := typeA.NumIn() == 1 && typeA.In(0).Kind() == reflect.Chan && typeA.In(0).Elem().Kind() == reflect.Interface + returnsBytes := typeA.NumOut() == 1 && typeA.Out(0).Kind() == reflect.Slice && typeA.Out(0).Elem().Kind() == reflect.Uint8 + + if !((takesNothing || takesADoneChannel) && returnsBytes) { + panic("SynchronizedBeforeSuite's first argument should be a function that returns []byte and either takes no arguments or takes a Done channel.") + } + + if takesADoneChannel { + return func(done chan<- interface{}) { + out := reflect.ValueOf(bodyA).Call([]reflect.Value{reflect.ValueOf(done)}) + node.data = out[0].Interface().([]byte) + } + } + + return func() { + out := reflect.ValueOf(bodyA).Call([]reflect.Value{}) + node.data = out[0].Interface().([]byte) + } +} + +func (node *synchronizedBeforeSuiteNode) wrapB(bodyB interface{}) interface{} { + typeB := reflect.TypeOf(bodyB) + if typeB.Kind() != reflect.Func { + panic("SynchronizedBeforeSuite expects a function as its second argument") + } + + returnsNothing := typeB.NumOut() == 0 + takesBytesOnly := typeB.NumIn() == 1 && typeB.In(0).Kind() == reflect.Slice && typeB.In(0).Elem().Kind() == reflect.Uint8 + takesBytesAndDone := typeB.NumIn() == 2 && + typeB.In(0).Kind() == reflect.Slice && typeB.In(0).Elem().Kind() == reflect.Uint8 && + typeB.In(1).Kind() == reflect.Chan && typeB.In(1).Elem().Kind() == reflect.Interface + + if !((takesBytesOnly || takesBytesAndDone) && returnsNothing) { + panic("SynchronizedBeforeSuite's second argument should be a function that returns nothing and either takes []byte or ([]byte, Done)") + } + + if takesBytesAndDone { + return func(done chan<- interface{}) { + reflect.ValueOf(bodyB).Call([]reflect.Value{reflect.ValueOf(node.data), reflect.ValueOf(done)}) + } + } + + return func() { + reflect.ValueOf(bodyB).Call([]reflect.Value{reflect.ValueOf(node.data)}) + } +} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go b/vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go new file mode 100644 index 00000000..992437d9 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go @@ -0,0 +1,249 @@ +/* + +Aggregator is a reporter used by the Ginkgo CLI to aggregate and present parallel test output +coherently as tests complete. You shouldn't need to use this in your code. To run tests in parallel: + + ginkgo -nodes=N + +where N is the number of nodes you desire. +*/ +package remote + +import ( + "time" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/reporters/stenographer" + "github.com/onsi/ginkgo/types" +) + +type configAndSuite struct { + config config.GinkgoConfigType + summary *types.SuiteSummary +} + +type Aggregator struct { + nodeCount int + config config.DefaultReporterConfigType + stenographer stenographer.Stenographer + result chan bool + + suiteBeginnings chan configAndSuite + aggregatedSuiteBeginnings []configAndSuite + + beforeSuites chan *types.SetupSummary + aggregatedBeforeSuites []*types.SetupSummary + + afterSuites chan *types.SetupSummary + aggregatedAfterSuites []*types.SetupSummary + + specCompletions chan *types.SpecSummary + completedSpecs []*types.SpecSummary + + suiteEndings chan *types.SuiteSummary + aggregatedSuiteEndings []*types.SuiteSummary + specs []*types.SpecSummary + + startTime time.Time +} + +func NewAggregator(nodeCount int, result chan bool, config config.DefaultReporterConfigType, stenographer stenographer.Stenographer) *Aggregator { + aggregator := &Aggregator{ + nodeCount: nodeCount, + result: result, + config: config, + stenographer: stenographer, + + suiteBeginnings: make(chan configAndSuite), + beforeSuites: make(chan *types.SetupSummary), + afterSuites: make(chan *types.SetupSummary), + specCompletions: make(chan *types.SpecSummary), + suiteEndings: make(chan *types.SuiteSummary), + } + + go aggregator.mux() + + return aggregator +} + +func (aggregator *Aggregator) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { + aggregator.suiteBeginnings <- configAndSuite{config, summary} +} + +func (aggregator *Aggregator) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { + aggregator.beforeSuites <- setupSummary +} + +func (aggregator *Aggregator) AfterSuiteDidRun(setupSummary *types.SetupSummary) { + aggregator.afterSuites <- setupSummary +} + +func (aggregator *Aggregator) SpecWillRun(specSummary *types.SpecSummary) { + //noop +} + +func (aggregator *Aggregator) SpecDidComplete(specSummary *types.SpecSummary) { + aggregator.specCompletions <- specSummary +} + +func (aggregator *Aggregator) SpecSuiteDidEnd(summary *types.SuiteSummary) { + aggregator.suiteEndings <- summary +} + +func (aggregator *Aggregator) mux() { +loop: + for { + select { + case configAndSuite := <-aggregator.suiteBeginnings: + aggregator.registerSuiteBeginning(configAndSuite) + case setupSummary := <-aggregator.beforeSuites: + aggregator.registerBeforeSuite(setupSummary) + case setupSummary := <-aggregator.afterSuites: + aggregator.registerAfterSuite(setupSummary) + case specSummary := <-aggregator.specCompletions: + aggregator.registerSpecCompletion(specSummary) + case suite := <-aggregator.suiteEndings: + finished, passed := aggregator.registerSuiteEnding(suite) + if finished { + aggregator.result <- passed + break loop + } + } + } +} + +func (aggregator *Aggregator) registerSuiteBeginning(configAndSuite configAndSuite) { + aggregator.aggregatedSuiteBeginnings = append(aggregator.aggregatedSuiteBeginnings, configAndSuite) + + if len(aggregator.aggregatedSuiteBeginnings) == 1 { + aggregator.startTime = time.Now() + } + + if len(aggregator.aggregatedSuiteBeginnings) != aggregator.nodeCount { + return + } + + aggregator.stenographer.AnnounceSuite(configAndSuite.summary.SuiteDescription, configAndSuite.config.RandomSeed, configAndSuite.config.RandomizeAllSpecs, aggregator.config.Succinct) + + totalNumberOfSpecs := 0 + if len(aggregator.aggregatedSuiteBeginnings) > 0 { + totalNumberOfSpecs = configAndSuite.summary.NumberOfSpecsBeforeParallelization + } + + aggregator.stenographer.AnnounceTotalNumberOfSpecs(totalNumberOfSpecs, aggregator.config.Succinct) + aggregator.stenographer.AnnounceAggregatedParallelRun(aggregator.nodeCount, aggregator.config.Succinct) + aggregator.flushCompletedSpecs() +} + +func (aggregator *Aggregator) registerBeforeSuite(setupSummary *types.SetupSummary) { + aggregator.aggregatedBeforeSuites = append(aggregator.aggregatedBeforeSuites, setupSummary) + aggregator.flushCompletedSpecs() +} + +func (aggregator *Aggregator) registerAfterSuite(setupSummary *types.SetupSummary) { + aggregator.aggregatedAfterSuites = append(aggregator.aggregatedAfterSuites, setupSummary) + aggregator.flushCompletedSpecs() +} + +func (aggregator *Aggregator) registerSpecCompletion(specSummary *types.SpecSummary) { + aggregator.completedSpecs = append(aggregator.completedSpecs, specSummary) + aggregator.specs = append(aggregator.specs, specSummary) + aggregator.flushCompletedSpecs() +} + +func (aggregator *Aggregator) flushCompletedSpecs() { + if len(aggregator.aggregatedSuiteBeginnings) != aggregator.nodeCount { + return + } + + for _, setupSummary := range aggregator.aggregatedBeforeSuites { + aggregator.announceBeforeSuite(setupSummary) + } + + for _, specSummary := range aggregator.completedSpecs { + aggregator.announceSpec(specSummary) + } + + for _, setupSummary := range aggregator.aggregatedAfterSuites { + aggregator.announceAfterSuite(setupSummary) + } + + aggregator.aggregatedBeforeSuites = []*types.SetupSummary{} + aggregator.completedSpecs = []*types.SpecSummary{} + aggregator.aggregatedAfterSuites = []*types.SetupSummary{} +} + +func (aggregator *Aggregator) announceBeforeSuite(setupSummary *types.SetupSummary) { + aggregator.stenographer.AnnounceCapturedOutput(setupSummary.CapturedOutput) + if setupSummary.State != types.SpecStatePassed { + aggregator.stenographer.AnnounceBeforeSuiteFailure(setupSummary, aggregator.config.Succinct, aggregator.config.FullTrace) + } +} + +func (aggregator *Aggregator) announceAfterSuite(setupSummary *types.SetupSummary) { + aggregator.stenographer.AnnounceCapturedOutput(setupSummary.CapturedOutput) + if setupSummary.State != types.SpecStatePassed { + aggregator.stenographer.AnnounceAfterSuiteFailure(setupSummary, aggregator.config.Succinct, aggregator.config.FullTrace) + } +} + +func (aggregator *Aggregator) announceSpec(specSummary *types.SpecSummary) { + if aggregator.config.Verbose && specSummary.State != types.SpecStatePending && specSummary.State != types.SpecStateSkipped { + aggregator.stenographer.AnnounceSpecWillRun(specSummary) + } + + aggregator.stenographer.AnnounceCapturedOutput(specSummary.CapturedOutput) + + switch specSummary.State { + case types.SpecStatePassed: + if specSummary.IsMeasurement { + aggregator.stenographer.AnnounceSuccessfulMeasurement(specSummary, aggregator.config.Succinct) + } else if specSummary.RunTime.Seconds() >= aggregator.config.SlowSpecThreshold { + aggregator.stenographer.AnnounceSuccessfulSlowSpec(specSummary, aggregator.config.Succinct) + } else { + aggregator.stenographer.AnnounceSuccessfulSpec(specSummary) + } + + case types.SpecStatePending: + aggregator.stenographer.AnnouncePendingSpec(specSummary, aggregator.config.NoisyPendings && !aggregator.config.Succinct) + case types.SpecStateSkipped: + aggregator.stenographer.AnnounceSkippedSpec(specSummary, aggregator.config.Succinct || !aggregator.config.NoisySkippings, aggregator.config.FullTrace) + case types.SpecStateTimedOut: + aggregator.stenographer.AnnounceSpecTimedOut(specSummary, aggregator.config.Succinct, aggregator.config.FullTrace) + case types.SpecStatePanicked: + aggregator.stenographer.AnnounceSpecPanicked(specSummary, aggregator.config.Succinct, aggregator.config.FullTrace) + case types.SpecStateFailed: + aggregator.stenographer.AnnounceSpecFailed(specSummary, aggregator.config.Succinct, aggregator.config.FullTrace) + } +} + +func (aggregator *Aggregator) registerSuiteEnding(suite *types.SuiteSummary) (finished bool, passed bool) { + aggregator.aggregatedSuiteEndings = append(aggregator.aggregatedSuiteEndings, suite) + if len(aggregator.aggregatedSuiteEndings) < aggregator.nodeCount { + return false, false + } + + aggregatedSuiteSummary := &types.SuiteSummary{} + aggregatedSuiteSummary.SuiteSucceeded = true + + for _, suiteSummary := range aggregator.aggregatedSuiteEndings { + if !suiteSummary.SuiteSucceeded { + aggregatedSuiteSummary.SuiteSucceeded = false + } + + aggregatedSuiteSummary.NumberOfSpecsThatWillBeRun += suiteSummary.NumberOfSpecsThatWillBeRun + aggregatedSuiteSummary.NumberOfTotalSpecs += suiteSummary.NumberOfTotalSpecs + aggregatedSuiteSummary.NumberOfPassedSpecs += suiteSummary.NumberOfPassedSpecs + aggregatedSuiteSummary.NumberOfFailedSpecs += suiteSummary.NumberOfFailedSpecs + aggregatedSuiteSummary.NumberOfPendingSpecs += suiteSummary.NumberOfPendingSpecs + aggregatedSuiteSummary.NumberOfSkippedSpecs += suiteSummary.NumberOfSkippedSpecs + aggregatedSuiteSummary.NumberOfFlakedSpecs += suiteSummary.NumberOfFlakedSpecs + } + + aggregatedSuiteSummary.RunTime = time.Since(aggregator.startTime) + + aggregator.stenographer.SummarizeFailures(aggregator.specs) + aggregator.stenographer.AnnounceSpecRunCompletion(aggregatedSuiteSummary, aggregator.config.Succinct) + + return true, aggregatedSuiteSummary.SuiteSucceeded +} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go b/vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go new file mode 100644 index 00000000..284bc62e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go @@ -0,0 +1,147 @@ +package remote + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + + "github.com/onsi/ginkgo/internal/writer" + "github.com/onsi/ginkgo/reporters" + "github.com/onsi/ginkgo/reporters/stenographer" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/types" +) + +//An interface to net/http's client to allow the injection of fakes under test +type Poster interface { + Post(url string, bodyType string, body io.Reader) (resp *http.Response, err error) +} + +/* +The ForwardingReporter is a Ginkgo reporter that forwards information to +a Ginkgo remote server. + +When streaming parallel test output, this repoter is automatically installed by Ginkgo. + +This is accomplished by passing in the GINKGO_REMOTE_REPORTING_SERVER environment variable to `go test`, the Ginkgo test runner +detects this environment variable (which should contain the host of the server) and automatically installs a ForwardingReporter +in place of Ginkgo's DefaultReporter. +*/ + +type ForwardingReporter struct { + serverHost string + poster Poster + outputInterceptor OutputInterceptor + debugMode bool + debugFile *os.File + nestedReporter *reporters.DefaultReporter +} + +func NewForwardingReporter(config config.DefaultReporterConfigType, serverHost string, poster Poster, outputInterceptor OutputInterceptor, ginkgoWriter *writer.Writer, debugFile string) *ForwardingReporter { + reporter := &ForwardingReporter{ + serverHost: serverHost, + poster: poster, + outputInterceptor: outputInterceptor, + } + + if debugFile != "" { + var err error + reporter.debugMode = true + reporter.debugFile, err = os.Create(debugFile) + if err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } + + if !config.Verbose { + //if verbose is true then the GinkgoWriter emits to stdout. Don't _also_ redirect GinkgoWriter output as that will result in duplication. + ginkgoWriter.AndRedirectTo(reporter.debugFile) + } + outputInterceptor.StreamTo(reporter.debugFile) //This is not working + + stenographer := stenographer.New(false, true, reporter.debugFile) + config.Succinct = false + config.Verbose = true + config.FullTrace = true + reporter.nestedReporter = reporters.NewDefaultReporter(config, stenographer) + } + + return reporter +} + +func (reporter *ForwardingReporter) post(path string, data interface{}) { + encoded, _ := json.Marshal(data) + buffer := bytes.NewBuffer(encoded) + reporter.poster.Post(reporter.serverHost+path, "application/json", buffer) +} + +func (reporter *ForwardingReporter) SpecSuiteWillBegin(conf config.GinkgoConfigType, summary *types.SuiteSummary) { + data := struct { + Config config.GinkgoConfigType `json:"config"` + Summary *types.SuiteSummary `json:"suite-summary"` + }{ + conf, + summary, + } + + reporter.outputInterceptor.StartInterceptingOutput() + if reporter.debugMode { + reporter.nestedReporter.SpecSuiteWillBegin(conf, summary) + reporter.debugFile.Sync() + } + reporter.post("/SpecSuiteWillBegin", data) +} + +func (reporter *ForwardingReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { + output, _ := reporter.outputInterceptor.StopInterceptingAndReturnOutput() + reporter.outputInterceptor.StartInterceptingOutput() + setupSummary.CapturedOutput = output + if reporter.debugMode { + reporter.nestedReporter.BeforeSuiteDidRun(setupSummary) + reporter.debugFile.Sync() + } + reporter.post("/BeforeSuiteDidRun", setupSummary) +} + +func (reporter *ForwardingReporter) SpecWillRun(specSummary *types.SpecSummary) { + if reporter.debugMode { + reporter.nestedReporter.SpecWillRun(specSummary) + reporter.debugFile.Sync() + } + reporter.post("/SpecWillRun", specSummary) +} + +func (reporter *ForwardingReporter) SpecDidComplete(specSummary *types.SpecSummary) { + output, _ := reporter.outputInterceptor.StopInterceptingAndReturnOutput() + reporter.outputInterceptor.StartInterceptingOutput() + specSummary.CapturedOutput = output + if reporter.debugMode { + reporter.nestedReporter.SpecDidComplete(specSummary) + reporter.debugFile.Sync() + } + reporter.post("/SpecDidComplete", specSummary) +} + +func (reporter *ForwardingReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { + output, _ := reporter.outputInterceptor.StopInterceptingAndReturnOutput() + reporter.outputInterceptor.StartInterceptingOutput() + setupSummary.CapturedOutput = output + if reporter.debugMode { + reporter.nestedReporter.AfterSuiteDidRun(setupSummary) + reporter.debugFile.Sync() + } + reporter.post("/AfterSuiteDidRun", setupSummary) +} + +func (reporter *ForwardingReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { + reporter.outputInterceptor.StopInterceptingAndReturnOutput() + if reporter.debugMode { + reporter.nestedReporter.SpecSuiteDidEnd(summary) + reporter.debugFile.Sync() + } + reporter.post("/SpecSuiteDidEnd", summary) +} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go new file mode 100644 index 00000000..5154abe8 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go @@ -0,0 +1,13 @@ +package remote + +import "os" + +/* +The OutputInterceptor is used by the ForwardingReporter to +intercept and capture all stdin and stderr output during a test run. +*/ +type OutputInterceptor interface { + StartInterceptingOutput() error + StopInterceptingAndReturnOutput() (string, error) + StreamTo(*os.File) +} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go new file mode 100644 index 00000000..774967db --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go @@ -0,0 +1,82 @@ +// +build freebsd openbsd netbsd dragonfly darwin linux solaris + +package remote + +import ( + "errors" + "io/ioutil" + "os" + + "github.com/nxadm/tail" + "golang.org/x/sys/unix" +) + +func NewOutputInterceptor() OutputInterceptor { + return &outputInterceptor{} +} + +type outputInterceptor struct { + redirectFile *os.File + streamTarget *os.File + intercepting bool + tailer *tail.Tail + doneTailing chan bool +} + +func (interceptor *outputInterceptor) StartInterceptingOutput() error { + if interceptor.intercepting { + return errors.New("Already intercepting output!") + } + interceptor.intercepting = true + + var err error + + interceptor.redirectFile, err = ioutil.TempFile("", "ginkgo-output") + if err != nil { + return err + } + + // This might call Dup3 if the dup2 syscall is not available, e.g. on + // linux/arm64 or linux/riscv64 + unix.Dup2(int(interceptor.redirectFile.Fd()), 1) + unix.Dup2(int(interceptor.redirectFile.Fd()), 2) + + if interceptor.streamTarget != nil { + interceptor.tailer, _ = tail.TailFile(interceptor.redirectFile.Name(), tail.Config{Follow: true}) + interceptor.doneTailing = make(chan bool) + + go func() { + for line := range interceptor.tailer.Lines { + interceptor.streamTarget.Write([]byte(line.Text + "\n")) + } + close(interceptor.doneTailing) + }() + } + + return nil +} + +func (interceptor *outputInterceptor) StopInterceptingAndReturnOutput() (string, error) { + if !interceptor.intercepting { + return "", errors.New("Not intercepting output!") + } + + interceptor.redirectFile.Close() + output, err := ioutil.ReadFile(interceptor.redirectFile.Name()) + os.Remove(interceptor.redirectFile.Name()) + + interceptor.intercepting = false + + if interceptor.streamTarget != nil { + interceptor.tailer.Stop() + interceptor.tailer.Cleanup() + <-interceptor.doneTailing + interceptor.streamTarget.Sync() + } + + return string(output), err +} + +func (interceptor *outputInterceptor) StreamTo(out *os.File) { + interceptor.streamTarget = out +} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go new file mode 100644 index 00000000..40c79033 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go @@ -0,0 +1,36 @@ +// +build windows + +package remote + +import ( + "errors" + "os" +) + +func NewOutputInterceptor() OutputInterceptor { + return &outputInterceptor{} +} + +type outputInterceptor struct { + intercepting bool +} + +func (interceptor *outputInterceptor) StartInterceptingOutput() error { + if interceptor.intercepting { + return errors.New("Already intercepting output!") + } + interceptor.intercepting = true + + // not working on windows... + + return nil +} + +func (interceptor *outputInterceptor) StopInterceptingAndReturnOutput() (string, error) { + // not working on windows... + interceptor.intercepting = false + + return "", nil +} + +func (interceptor *outputInterceptor) StreamTo(*os.File) {} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/server.go b/vendor/github.com/onsi/ginkgo/internal/remote/server.go new file mode 100644 index 00000000..93e9dac0 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/remote/server.go @@ -0,0 +1,224 @@ +/* + +The remote package provides the pieces to allow Ginkgo test suites to report to remote listeners. +This is used, primarily, to enable streaming parallel test output but has, in principal, broader applications (e.g. streaming test output to a browser). + +*/ + +package remote + +import ( + "encoding/json" + "io/ioutil" + "net" + "net/http" + "sync" + + "github.com/onsi/ginkgo/internal/spec_iterator" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/reporters" + "github.com/onsi/ginkgo/types" +) + +/* +Server spins up on an automatically selected port and listens for communication from the forwarding reporter. +It then forwards that communication to attached reporters. +*/ +type Server struct { + listener net.Listener + reporters []reporters.Reporter + alives []func() bool + lock *sync.Mutex + beforeSuiteData types.RemoteBeforeSuiteData + parallelTotal int + counter int +} + +//Create a new server, automatically selecting a port +func NewServer(parallelTotal int) (*Server, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + return &Server{ + listener: listener, + lock: &sync.Mutex{}, + alives: make([]func() bool, parallelTotal), + beforeSuiteData: types.RemoteBeforeSuiteData{Data: nil, State: types.RemoteBeforeSuiteStatePending}, + parallelTotal: parallelTotal, + }, nil +} + +//Start the server. You don't need to `go s.Start()`, just `s.Start()` +func (server *Server) Start() { + httpServer := &http.Server{} + mux := http.NewServeMux() + httpServer.Handler = mux + + //streaming endpoints + mux.HandleFunc("/SpecSuiteWillBegin", server.specSuiteWillBegin) + mux.HandleFunc("/BeforeSuiteDidRun", server.beforeSuiteDidRun) + mux.HandleFunc("/AfterSuiteDidRun", server.afterSuiteDidRun) + mux.HandleFunc("/SpecWillRun", server.specWillRun) + mux.HandleFunc("/SpecDidComplete", server.specDidComplete) + mux.HandleFunc("/SpecSuiteDidEnd", server.specSuiteDidEnd) + + //synchronization endpoints + mux.HandleFunc("/BeforeSuiteState", server.handleBeforeSuiteState) + mux.HandleFunc("/RemoteAfterSuiteData", server.handleRemoteAfterSuiteData) + mux.HandleFunc("/counter", server.handleCounter) + mux.HandleFunc("/has-counter", server.handleHasCounter) //for backward compatibility + + go httpServer.Serve(server.listener) +} + +//Stop the server +func (server *Server) Close() { + server.listener.Close() +} + +//The address the server can be reached it. Pass this into the `ForwardingReporter`. +func (server *Server) Address() string { + return "http://" + server.listener.Addr().String() +} + +// +// Streaming Endpoints +// + +//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters` +func (server *Server) readAll(request *http.Request) []byte { + defer request.Body.Close() + body, _ := ioutil.ReadAll(request.Body) + return body +} + +func (server *Server) RegisterReporters(reporters ...reporters.Reporter) { + server.reporters = reporters +} + +func (server *Server) specSuiteWillBegin(writer http.ResponseWriter, request *http.Request) { + body := server.readAll(request) + + var data struct { + Config config.GinkgoConfigType `json:"config"` + Summary *types.SuiteSummary `json:"suite-summary"` + } + + json.Unmarshal(body, &data) + + for _, reporter := range server.reporters { + reporter.SpecSuiteWillBegin(data.Config, data.Summary) + } +} + +func (server *Server) beforeSuiteDidRun(writer http.ResponseWriter, request *http.Request) { + body := server.readAll(request) + var setupSummary *types.SetupSummary + json.Unmarshal(body, &setupSummary) + + for _, reporter := range server.reporters { + reporter.BeforeSuiteDidRun(setupSummary) + } +} + +func (server *Server) afterSuiteDidRun(writer http.ResponseWriter, request *http.Request) { + body := server.readAll(request) + var setupSummary *types.SetupSummary + json.Unmarshal(body, &setupSummary) + + for _, reporter := range server.reporters { + reporter.AfterSuiteDidRun(setupSummary) + } +} + +func (server *Server) specWillRun(writer http.ResponseWriter, request *http.Request) { + body := server.readAll(request) + var specSummary *types.SpecSummary + json.Unmarshal(body, &specSummary) + + for _, reporter := range server.reporters { + reporter.SpecWillRun(specSummary) + } +} + +func (server *Server) specDidComplete(writer http.ResponseWriter, request *http.Request) { + body := server.readAll(request) + var specSummary *types.SpecSummary + json.Unmarshal(body, &specSummary) + + for _, reporter := range server.reporters { + reporter.SpecDidComplete(specSummary) + } +} + +func (server *Server) specSuiteDidEnd(writer http.ResponseWriter, request *http.Request) { + body := server.readAll(request) + var suiteSummary *types.SuiteSummary + json.Unmarshal(body, &suiteSummary) + + for _, reporter := range server.reporters { + reporter.SpecSuiteDidEnd(suiteSummary) + } +} + +// +// Synchronization Endpoints +// + +func (server *Server) RegisterAlive(node int, alive func() bool) { + server.lock.Lock() + defer server.lock.Unlock() + server.alives[node-1] = alive +} + +func (server *Server) nodeIsAlive(node int) bool { + server.lock.Lock() + defer server.lock.Unlock() + alive := server.alives[node-1] + if alive == nil { + return true + } + return alive() +} + +func (server *Server) handleBeforeSuiteState(writer http.ResponseWriter, request *http.Request) { + if request.Method == "POST" { + dec := json.NewDecoder(request.Body) + dec.Decode(&(server.beforeSuiteData)) + } else { + beforeSuiteData := server.beforeSuiteData + if beforeSuiteData.State == types.RemoteBeforeSuiteStatePending && !server.nodeIsAlive(1) { + beforeSuiteData.State = types.RemoteBeforeSuiteStateDisappeared + } + enc := json.NewEncoder(writer) + enc.Encode(beforeSuiteData) + } +} + +func (server *Server) handleRemoteAfterSuiteData(writer http.ResponseWriter, request *http.Request) { + afterSuiteData := types.RemoteAfterSuiteData{ + CanRun: true, + } + for i := 2; i <= server.parallelTotal; i++ { + afterSuiteData.CanRun = afterSuiteData.CanRun && !server.nodeIsAlive(i) + } + + enc := json.NewEncoder(writer) + enc.Encode(afterSuiteData) +} + +func (server *Server) handleCounter(writer http.ResponseWriter, request *http.Request) { + c := spec_iterator.Counter{} + server.lock.Lock() + c.Index = server.counter + server.counter++ + server.lock.Unlock() + + json.NewEncoder(writer).Encode(c) +} + +func (server *Server) handleHasCounter(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("")) +} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec/spec.go b/vendor/github.com/onsi/ginkgo/internal/spec/spec.go new file mode 100644 index 00000000..6eef40a0 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/spec/spec.go @@ -0,0 +1,247 @@ +package spec + +import ( + "fmt" + "io" + "time" + + "sync" + + "github.com/onsi/ginkgo/internal/containernode" + "github.com/onsi/ginkgo/internal/leafnodes" + "github.com/onsi/ginkgo/types" +) + +type Spec struct { + subject leafnodes.SubjectNode + focused bool + announceProgress bool + + containers []*containernode.ContainerNode + + state types.SpecState + runTime time.Duration + startTime time.Time + failure types.SpecFailure + previousFailures bool + + stateMutex *sync.Mutex +} + +func New(subject leafnodes.SubjectNode, containers []*containernode.ContainerNode, announceProgress bool) *Spec { + spec := &Spec{ + subject: subject, + containers: containers, + focused: subject.Flag() == types.FlagTypeFocused, + announceProgress: announceProgress, + stateMutex: &sync.Mutex{}, + } + + spec.processFlag(subject.Flag()) + for i := len(containers) - 1; i >= 0; i-- { + spec.processFlag(containers[i].Flag()) + } + + return spec +} + +func (spec *Spec) processFlag(flag types.FlagType) { + if flag == types.FlagTypeFocused { + spec.focused = true + } else if flag == types.FlagTypePending { + spec.setState(types.SpecStatePending) + } +} + +func (spec *Spec) Skip() { + spec.setState(types.SpecStateSkipped) +} + +func (spec *Spec) Failed() bool { + return spec.getState() == types.SpecStateFailed || spec.getState() == types.SpecStatePanicked || spec.getState() == types.SpecStateTimedOut +} + +func (spec *Spec) Passed() bool { + return spec.getState() == types.SpecStatePassed +} + +func (spec *Spec) Flaked() bool { + return spec.getState() == types.SpecStatePassed && spec.previousFailures +} + +func (spec *Spec) Pending() bool { + return spec.getState() == types.SpecStatePending +} + +func (spec *Spec) Skipped() bool { + return spec.getState() == types.SpecStateSkipped +} + +func (spec *Spec) Focused() bool { + return spec.focused +} + +func (spec *Spec) IsMeasurement() bool { + return spec.subject.Type() == types.SpecComponentTypeMeasure +} + +func (spec *Spec) Summary(suiteID string) *types.SpecSummary { + componentTexts := make([]string, len(spec.containers)+1) + componentCodeLocations := make([]types.CodeLocation, len(spec.containers)+1) + + for i, container := range spec.containers { + componentTexts[i] = container.Text() + componentCodeLocations[i] = container.CodeLocation() + } + + componentTexts[len(spec.containers)] = spec.subject.Text() + componentCodeLocations[len(spec.containers)] = spec.subject.CodeLocation() + + runTime := spec.runTime + if runTime == 0 && !spec.startTime.IsZero() { + runTime = time.Since(spec.startTime) + } + + return &types.SpecSummary{ + IsMeasurement: spec.IsMeasurement(), + NumberOfSamples: spec.subject.Samples(), + ComponentTexts: componentTexts, + ComponentCodeLocations: componentCodeLocations, + State: spec.getState(), + RunTime: runTime, + Failure: spec.failure, + Measurements: spec.measurementsReport(), + SuiteID: suiteID, + } +} + +func (spec *Spec) ConcatenatedString() string { + s := "" + for _, container := range spec.containers { + s += container.Text() + " " + } + + return s + spec.subject.Text() +} + +func (spec *Spec) Run(writer io.Writer) { + if spec.getState() == types.SpecStateFailed { + spec.previousFailures = true + } + + spec.startTime = time.Now() + defer func() { + spec.runTime = time.Since(spec.startTime) + }() + + for sample := 0; sample < spec.subject.Samples(); sample++ { + spec.runSample(sample, writer) + + if spec.getState() != types.SpecStatePassed { + return + } + } +} + +func (spec *Spec) getState() types.SpecState { + spec.stateMutex.Lock() + defer spec.stateMutex.Unlock() + return spec.state +} + +func (spec *Spec) setState(state types.SpecState) { + spec.stateMutex.Lock() + defer spec.stateMutex.Unlock() + spec.state = state +} + +func (spec *Spec) runSample(sample int, writer io.Writer) { + spec.setState(types.SpecStatePassed) + spec.failure = types.SpecFailure{} + innerMostContainerIndexToUnwind := -1 + + defer func() { + for i := innerMostContainerIndexToUnwind; i >= 0; i-- { + container := spec.containers[i] + for _, justAfterEach := range container.SetupNodesOfType(types.SpecComponentTypeJustAfterEach) { + spec.announceSetupNode(writer, "JustAfterEach", container, justAfterEach) + justAfterEachState, justAfterEachFailure := justAfterEach.Run() + if justAfterEachState != types.SpecStatePassed && spec.state == types.SpecStatePassed { + spec.state = justAfterEachState + spec.failure = justAfterEachFailure + } + } + } + + for i := innerMostContainerIndexToUnwind; i >= 0; i-- { + container := spec.containers[i] + for _, afterEach := range container.SetupNodesOfType(types.SpecComponentTypeAfterEach) { + spec.announceSetupNode(writer, "AfterEach", container, afterEach) + afterEachState, afterEachFailure := afterEach.Run() + if afterEachState != types.SpecStatePassed && spec.getState() == types.SpecStatePassed { + spec.setState(afterEachState) + spec.failure = afterEachFailure + } + } + } + }() + + for i, container := range spec.containers { + innerMostContainerIndexToUnwind = i + for _, beforeEach := range container.SetupNodesOfType(types.SpecComponentTypeBeforeEach) { + spec.announceSetupNode(writer, "BeforeEach", container, beforeEach) + s, f := beforeEach.Run() + spec.failure = f + spec.setState(s) + if spec.getState() != types.SpecStatePassed { + return + } + } + } + + for _, container := range spec.containers { + for _, justBeforeEach := range container.SetupNodesOfType(types.SpecComponentTypeJustBeforeEach) { + spec.announceSetupNode(writer, "JustBeforeEach", container, justBeforeEach) + s, f := justBeforeEach.Run() + spec.failure = f + spec.setState(s) + if spec.getState() != types.SpecStatePassed { + return + } + } + } + + spec.announceSubject(writer, spec.subject) + s, f := spec.subject.Run() + spec.failure = f + spec.setState(s) +} + +func (spec *Spec) announceSetupNode(writer io.Writer, nodeType string, container *containernode.ContainerNode, setupNode leafnodes.BasicNode) { + if spec.announceProgress { + s := fmt.Sprintf("[%s] %s\n %s\n", nodeType, container.Text(), setupNode.CodeLocation().String()) + writer.Write([]byte(s)) + } +} + +func (spec *Spec) announceSubject(writer io.Writer, subject leafnodes.SubjectNode) { + if spec.announceProgress { + nodeType := "" + switch subject.Type() { + case types.SpecComponentTypeIt: + nodeType = "It" + case types.SpecComponentTypeMeasure: + nodeType = "Measure" + } + s := fmt.Sprintf("[%s] %s\n %s\n", nodeType, subject.Text(), subject.CodeLocation().String()) + writer.Write([]byte(s)) + } +} + +func (spec *Spec) measurementsReport() map[string]*types.SpecMeasurement { + if !spec.IsMeasurement() || spec.Failed() { + return map[string]*types.SpecMeasurement{} + } + + return spec.subject.(*leafnodes.MeasureNode).MeasurementsReport() +} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec/specs.go b/vendor/github.com/onsi/ginkgo/internal/spec/specs.go new file mode 100644 index 00000000..0a24139f --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/spec/specs.go @@ -0,0 +1,144 @@ +package spec + +import ( + "math/rand" + "regexp" + "sort" + "strings" +) + +type Specs struct { + specs []*Spec + names []string + + hasProgrammaticFocus bool + RegexScansFilePath bool +} + +func NewSpecs(specs []*Spec) *Specs { + names := make([]string, len(specs)) + for i, spec := range specs { + names[i] = spec.ConcatenatedString() + } + return &Specs{ + specs: specs, + names: names, + } +} + +func (e *Specs) Specs() []*Spec { + return e.specs +} + +func (e *Specs) HasProgrammaticFocus() bool { + return e.hasProgrammaticFocus +} + +func (e *Specs) Shuffle(r *rand.Rand) { + sort.Sort(e) + permutation := r.Perm(len(e.specs)) + shuffledSpecs := make([]*Spec, len(e.specs)) + names := make([]string, len(e.specs)) + for i, j := range permutation { + shuffledSpecs[i] = e.specs[j] + names[i] = e.names[j] + } + e.specs = shuffledSpecs + e.names = names +} + +func (e *Specs) ApplyFocus(description string, focus, skip []string) { + if len(focus)+len(skip) == 0 { + e.applyProgrammaticFocus() + } else { + e.applyRegExpFocusAndSkip(description, focus, skip) + } +} + +func (e *Specs) applyProgrammaticFocus() { + e.hasProgrammaticFocus = false + for _, spec := range e.specs { + if spec.Focused() && !spec.Pending() { + e.hasProgrammaticFocus = true + break + } + } + + if e.hasProgrammaticFocus { + for _, spec := range e.specs { + if !spec.Focused() { + spec.Skip() + } + } + } +} + +// toMatch returns a byte[] to be used by regex matchers. When adding new behaviours to the matching function, +// this is the place which we append to. +func (e *Specs) toMatch(description string, i int) []byte { + if i > len(e.names) { + return nil + } + if e.RegexScansFilePath { + return []byte( + description + " " + + e.names[i] + " " + + e.specs[i].subject.CodeLocation().FileName) + } else { + return []byte( + description + " " + + e.names[i]) + } +} + +func (e *Specs) applyRegExpFocusAndSkip(description string, focus, skip []string) { + var focusFilter, skipFilter *regexp.Regexp + if len(focus) > 0 { + focusFilter = regexp.MustCompile(strings.Join(focus, "|")) + } + if len(skip) > 0 { + skipFilter = regexp.MustCompile(strings.Join(skip, "|")) + } + + for i, spec := range e.specs { + matchesFocus := true + matchesSkip := false + + toMatch := e.toMatch(description, i) + + if focusFilter != nil { + matchesFocus = focusFilter.Match(toMatch) + } + + if skipFilter != nil { + matchesSkip = skipFilter.Match(toMatch) + } + + if !matchesFocus || matchesSkip { + spec.Skip() + } + } +} + +func (e *Specs) SkipMeasurements() { + for _, spec := range e.specs { + if spec.IsMeasurement() { + spec.Skip() + } + } +} + +//sort.Interface + +func (e *Specs) Len() int { + return len(e.specs) +} + +func (e *Specs) Less(i, j int) bool { + return e.names[i] < e.names[j] +} + +func (e *Specs) Swap(i, j int) { + e.names[i], e.names[j] = e.names[j], e.names[i] + e.specs[i], e.specs[j] = e.specs[j], e.specs[i] +} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go new file mode 100644 index 00000000..82272554 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go @@ -0,0 +1,55 @@ +package spec_iterator + +func ParallelizedIndexRange(length int, parallelTotal int, parallelNode int) (startIndex int, count int) { + if length == 0 { + return 0, 0 + } + + // We have more nodes than tests. Trivial case. + if parallelTotal >= length { + if parallelNode > length { + return 0, 0 + } else { + return parallelNode - 1, 1 + } + } + + // This is the minimum amount of tests that a node will be required to run + minTestsPerNode := length / parallelTotal + + // This is the maximum amount of tests that a node will be required to run + // The algorithm guarantees that this would be equal to at least the minimum amount + // and at most one more + maxTestsPerNode := minTestsPerNode + if length%parallelTotal != 0 { + maxTestsPerNode++ + } + + // Number of nodes that will have to run the maximum amount of tests per node + numMaxLoadNodes := length % parallelTotal + + // Number of nodes that precede the current node and will have to run the maximum amount of tests per node + var numPrecedingMaxLoadNodes int + if parallelNode > numMaxLoadNodes { + numPrecedingMaxLoadNodes = numMaxLoadNodes + } else { + numPrecedingMaxLoadNodes = parallelNode - 1 + } + + // Number of nodes that precede the current node and will have to run the minimum amount of tests per node + var numPrecedingMinLoadNodes int + if parallelNode <= numMaxLoadNodes { + numPrecedingMinLoadNodes = 0 + } else { + numPrecedingMinLoadNodes = parallelNode - numMaxLoadNodes - 1 + } + + // Evaluate the test start index and number of tests to run + startIndex = numPrecedingMaxLoadNodes*maxTestsPerNode + numPrecedingMinLoadNodes*minTestsPerNode + if parallelNode > numMaxLoadNodes { + count = minTestsPerNode + } else { + count = maxTestsPerNode + } + return +} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go new file mode 100644 index 00000000..99f548bc --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go @@ -0,0 +1,59 @@ +package spec_iterator + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/onsi/ginkgo/internal/spec" +) + +type ParallelIterator struct { + specs []*spec.Spec + host string + client *http.Client +} + +func NewParallelIterator(specs []*spec.Spec, host string) *ParallelIterator { + return &ParallelIterator{ + specs: specs, + host: host, + client: &http.Client{}, + } +} + +func (s *ParallelIterator) Next() (*spec.Spec, error) { + resp, err := s.client.Get(s.host + "/counter") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + + var counter Counter + err = json.NewDecoder(resp.Body).Decode(&counter) + if err != nil { + return nil, err + } + + if counter.Index >= len(s.specs) { + return nil, ErrClosed + } + + return s.specs[counter.Index], nil +} + +func (s *ParallelIterator) NumberOfSpecsPriorToIteration() int { + return len(s.specs) +} + +func (s *ParallelIterator) NumberOfSpecsToProcessIfKnown() (int, bool) { + return -1, false +} + +func (s *ParallelIterator) NumberOfSpecsThatWillBeRunIfKnown() (int, bool) { + return -1, false +} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go new file mode 100644 index 00000000..a51c93b8 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go @@ -0,0 +1,45 @@ +package spec_iterator + +import ( + "github.com/onsi/ginkgo/internal/spec" +) + +type SerialIterator struct { + specs []*spec.Spec + index int +} + +func NewSerialIterator(specs []*spec.Spec) *SerialIterator { + return &SerialIterator{ + specs: specs, + index: 0, + } +} + +func (s *SerialIterator) Next() (*spec.Spec, error) { + if s.index >= len(s.specs) { + return nil, ErrClosed + } + + spec := s.specs[s.index] + s.index += 1 + return spec, nil +} + +func (s *SerialIterator) NumberOfSpecsPriorToIteration() int { + return len(s.specs) +} + +func (s *SerialIterator) NumberOfSpecsToProcessIfKnown() (int, bool) { + return len(s.specs), true +} + +func (s *SerialIterator) NumberOfSpecsThatWillBeRunIfKnown() (int, bool) { + count := 0 + for _, s := range s.specs { + if !s.Skipped() && !s.Pending() { + count += 1 + } + } + return count, true +} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go new file mode 100644 index 00000000..ad4a3ea3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go @@ -0,0 +1,47 @@ +package spec_iterator + +import "github.com/onsi/ginkgo/internal/spec" + +type ShardedParallelIterator struct { + specs []*spec.Spec + index int + maxIndex int +} + +func NewShardedParallelIterator(specs []*spec.Spec, total int, node int) *ShardedParallelIterator { + startIndex, count := ParallelizedIndexRange(len(specs), total, node) + + return &ShardedParallelIterator{ + specs: specs, + index: startIndex, + maxIndex: startIndex + count, + } +} + +func (s *ShardedParallelIterator) Next() (*spec.Spec, error) { + if s.index >= s.maxIndex { + return nil, ErrClosed + } + + spec := s.specs[s.index] + s.index += 1 + return spec, nil +} + +func (s *ShardedParallelIterator) NumberOfSpecsPriorToIteration() int { + return len(s.specs) +} + +func (s *ShardedParallelIterator) NumberOfSpecsToProcessIfKnown() (int, bool) { + return s.maxIndex - s.index, true +} + +func (s *ShardedParallelIterator) NumberOfSpecsThatWillBeRunIfKnown() (int, bool) { + count := 0 + for i := s.index; i < s.maxIndex; i += 1 { + if !s.specs[i].Skipped() && !s.specs[i].Pending() { + count += 1 + } + } + return count, true +} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go new file mode 100644 index 00000000..74bffad6 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go @@ -0,0 +1,20 @@ +package spec_iterator + +import ( + "errors" + + "github.com/onsi/ginkgo/internal/spec" +) + +var ErrClosed = errors.New("no more specs to run") + +type SpecIterator interface { + Next() (*spec.Spec, error) + NumberOfSpecsPriorToIteration() int + NumberOfSpecsToProcessIfKnown() (int, bool) + NumberOfSpecsThatWillBeRunIfKnown() (int, bool) +} + +type Counter struct { + Index int `json:"index"` +} diff --git a/vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go b/vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go new file mode 100644 index 00000000..6739c3f6 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go @@ -0,0 +1,36 @@ +package writer + +type FakeGinkgoWriter struct { + EventStream []string +} + +func NewFake() *FakeGinkgoWriter { + return &FakeGinkgoWriter{ + EventStream: []string{}, + } +} + +func (writer *FakeGinkgoWriter) AddEvent(event string) { + writer.EventStream = append(writer.EventStream, event) +} + +func (writer *FakeGinkgoWriter) Truncate() { + writer.EventStream = append(writer.EventStream, "TRUNCATE") +} + +func (writer *FakeGinkgoWriter) DumpOut() { + writer.EventStream = append(writer.EventStream, "DUMP") +} + +func (writer *FakeGinkgoWriter) DumpOutWithHeader(header string) { + writer.EventStream = append(writer.EventStream, "DUMP_WITH_HEADER: "+header) +} + +func (writer *FakeGinkgoWriter) Bytes() []byte { + writer.EventStream = append(writer.EventStream, "BYTES") + return nil +} + +func (writer *FakeGinkgoWriter) Write(data []byte) (n int, err error) { + return 0, nil +} diff --git a/vendor/github.com/onsi/ginkgo/internal/writer/writer.go b/vendor/github.com/onsi/ginkgo/internal/writer/writer.go new file mode 100644 index 00000000..98eca3bd --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/internal/writer/writer.go @@ -0,0 +1,89 @@ +package writer + +import ( + "bytes" + "io" + "sync" +) + +type WriterInterface interface { + io.Writer + + Truncate() + DumpOut() + DumpOutWithHeader(header string) + Bytes() []byte +} + +type Writer struct { + buffer *bytes.Buffer + outWriter io.Writer + lock *sync.Mutex + stream bool + redirector io.Writer +} + +func New(outWriter io.Writer) *Writer { + return &Writer{ + buffer: &bytes.Buffer{}, + lock: &sync.Mutex{}, + outWriter: outWriter, + stream: true, + } +} + +func (w *Writer) AndRedirectTo(writer io.Writer) { + w.redirector = writer +} + +func (w *Writer) SetStream(stream bool) { + w.lock.Lock() + defer w.lock.Unlock() + w.stream = stream +} + +func (w *Writer) Write(b []byte) (n int, err error) { + w.lock.Lock() + defer w.lock.Unlock() + + n, err = w.buffer.Write(b) + if w.redirector != nil { + w.redirector.Write(b) + } + if w.stream { + return w.outWriter.Write(b) + } + return n, err +} + +func (w *Writer) Truncate() { + w.lock.Lock() + defer w.lock.Unlock() + w.buffer.Reset() +} + +func (w *Writer) DumpOut() { + w.lock.Lock() + defer w.lock.Unlock() + if !w.stream { + w.buffer.WriteTo(w.outWriter) + } +} + +func (w *Writer) Bytes() []byte { + w.lock.Lock() + defer w.lock.Unlock() + b := w.buffer.Bytes() + copied := make([]byte, len(b)) + copy(copied, b) + return copied +} + +func (w *Writer) DumpOutWithHeader(header string) { + w.lock.Lock() + defer w.lock.Unlock() + if !w.stream && w.buffer.Len() > 0 { + w.outWriter.Write([]byte(header)) + w.buffer.WriteTo(w.outWriter) + } +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/default_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/default_reporter.go new file mode 100644 index 00000000..f0c9f614 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/default_reporter.go @@ -0,0 +1,87 @@ +/* +Ginkgo's Default Reporter + +A number of command line flags are available to tweak Ginkgo's default output. + +These are documented [here](http://onsi.github.io/ginkgo/#running_tests) +*/ +package reporters + +import ( + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/reporters/stenographer" + "github.com/onsi/ginkgo/types" +) + +type DefaultReporter struct { + config config.DefaultReporterConfigType + stenographer stenographer.Stenographer + specSummaries []*types.SpecSummary +} + +func NewDefaultReporter(config config.DefaultReporterConfigType, stenographer stenographer.Stenographer) *DefaultReporter { + return &DefaultReporter{ + config: config, + stenographer: stenographer, + } +} + +func (reporter *DefaultReporter) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { + reporter.stenographer.AnnounceSuite(summary.SuiteDescription, config.RandomSeed, config.RandomizeAllSpecs, reporter.config.Succinct) + if config.ParallelTotal > 1 { + reporter.stenographer.AnnounceParallelRun(config.ParallelNode, config.ParallelTotal, reporter.config.Succinct) + } else { + reporter.stenographer.AnnounceNumberOfSpecs(summary.NumberOfSpecsThatWillBeRun, summary.NumberOfTotalSpecs, reporter.config.Succinct) + } +} + +func (reporter *DefaultReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { + if setupSummary.State != types.SpecStatePassed { + reporter.stenographer.AnnounceBeforeSuiteFailure(setupSummary, reporter.config.Succinct, reporter.config.FullTrace) + } +} + +func (reporter *DefaultReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { + if setupSummary.State != types.SpecStatePassed { + reporter.stenographer.AnnounceAfterSuiteFailure(setupSummary, reporter.config.Succinct, reporter.config.FullTrace) + } +} + +func (reporter *DefaultReporter) SpecWillRun(specSummary *types.SpecSummary) { + if reporter.config.Verbose && !reporter.config.Succinct && specSummary.State != types.SpecStatePending && specSummary.State != types.SpecStateSkipped { + reporter.stenographer.AnnounceSpecWillRun(specSummary) + } +} + +func (reporter *DefaultReporter) SpecDidComplete(specSummary *types.SpecSummary) { + switch specSummary.State { + case types.SpecStatePassed: + if specSummary.IsMeasurement { + reporter.stenographer.AnnounceSuccessfulMeasurement(specSummary, reporter.config.Succinct) + } else if specSummary.RunTime.Seconds() >= reporter.config.SlowSpecThreshold { + reporter.stenographer.AnnounceSuccessfulSlowSpec(specSummary, reporter.config.Succinct) + } else { + reporter.stenographer.AnnounceSuccessfulSpec(specSummary) + if reporter.config.ReportPassed { + reporter.stenographer.AnnounceCapturedOutput(specSummary.CapturedOutput) + } + } + case types.SpecStatePending: + reporter.stenographer.AnnouncePendingSpec(specSummary, reporter.config.NoisyPendings && !reporter.config.Succinct) + case types.SpecStateSkipped: + reporter.stenographer.AnnounceSkippedSpec(specSummary, reporter.config.Succinct || !reporter.config.NoisySkippings, reporter.config.FullTrace) + case types.SpecStateTimedOut: + reporter.stenographer.AnnounceSpecTimedOut(specSummary, reporter.config.Succinct, reporter.config.FullTrace) + case types.SpecStatePanicked: + reporter.stenographer.AnnounceSpecPanicked(specSummary, reporter.config.Succinct, reporter.config.FullTrace) + case types.SpecStateFailed: + reporter.stenographer.AnnounceSpecFailed(specSummary, reporter.config.Succinct, reporter.config.FullTrace) + } + + reporter.specSummaries = append(reporter.specSummaries, specSummary) +} + +func (reporter *DefaultReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { + reporter.stenographer.SummarizeFailures(reporter.specSummaries) + reporter.stenographer.AnnounceSpecRunCompletion(summary, reporter.config.Succinct) +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go new file mode 100644 index 00000000..27db4794 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go @@ -0,0 +1,59 @@ +package reporters + +import ( + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/types" +) + +//FakeReporter is useful for testing purposes +type FakeReporter struct { + Config config.GinkgoConfigType + + BeginSummary *types.SuiteSummary + BeforeSuiteSummary *types.SetupSummary + SpecWillRunSummaries []*types.SpecSummary + SpecSummaries []*types.SpecSummary + AfterSuiteSummary *types.SetupSummary + EndSummary *types.SuiteSummary + + SpecWillRunStub func(specSummary *types.SpecSummary) + SpecDidCompleteStub func(specSummary *types.SpecSummary) +} + +func NewFakeReporter() *FakeReporter { + return &FakeReporter{ + SpecWillRunSummaries: make([]*types.SpecSummary, 0), + SpecSummaries: make([]*types.SpecSummary, 0), + } +} + +func (fakeR *FakeReporter) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { + fakeR.Config = config + fakeR.BeginSummary = summary +} + +func (fakeR *FakeReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { + fakeR.BeforeSuiteSummary = setupSummary +} + +func (fakeR *FakeReporter) SpecWillRun(specSummary *types.SpecSummary) { + if fakeR.SpecWillRunStub != nil { + fakeR.SpecWillRunStub(specSummary) + } + fakeR.SpecWillRunSummaries = append(fakeR.SpecWillRunSummaries, specSummary) +} + +func (fakeR *FakeReporter) SpecDidComplete(specSummary *types.SpecSummary) { + if fakeR.SpecDidCompleteStub != nil { + fakeR.SpecDidCompleteStub(specSummary) + } + fakeR.SpecSummaries = append(fakeR.SpecSummaries, specSummary) +} + +func (fakeR *FakeReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { + fakeR.AfterSuiteSummary = setupSummary +} + +func (fakeR *FakeReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { + fakeR.EndSummary = summary +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go new file mode 100644 index 00000000..01ddca6e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go @@ -0,0 +1,178 @@ +/* + +JUnit XML Reporter for Ginkgo + +For usage instructions: http://onsi.github.io/ginkgo/#generating_junit_xml_output + +*/ + +package reporters + +import ( + "encoding/xml" + "fmt" + "math" + "os" + "path/filepath" + "strings" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/types" +) + +type JUnitTestSuite struct { + XMLName xml.Name `xml:"testsuite"` + TestCases []JUnitTestCase `xml:"testcase"` + Name string `xml:"name,attr"` + Tests int `xml:"tests,attr"` + Failures int `xml:"failures,attr"` + Errors int `xml:"errors,attr"` + Time float64 `xml:"time,attr"` +} + +type JUnitTestCase struct { + Name string `xml:"name,attr"` + ClassName string `xml:"classname,attr"` + FailureMessage *JUnitFailureMessage `xml:"failure,omitempty"` + Skipped *JUnitSkipped `xml:"skipped,omitempty"` + Time float64 `xml:"time,attr"` + SystemOut string `xml:"system-out,omitempty"` +} + +type JUnitFailureMessage struct { + Type string `xml:"type,attr"` + Message string `xml:",chardata"` +} + +type JUnitSkipped struct { + Message string `xml:",chardata"` +} + +type JUnitReporter struct { + suite JUnitTestSuite + filename string + testSuiteName string + ReporterConfig config.DefaultReporterConfigType +} + +//NewJUnitReporter creates a new JUnit XML reporter. The XML will be stored in the passed in filename. +func NewJUnitReporter(filename string) *JUnitReporter { + return &JUnitReporter{ + filename: filename, + } +} + +func (reporter *JUnitReporter) SpecSuiteWillBegin(ginkgoConfig config.GinkgoConfigType, summary *types.SuiteSummary) { + reporter.suite = JUnitTestSuite{ + Name: summary.SuiteDescription, + TestCases: []JUnitTestCase{}, + } + reporter.testSuiteName = summary.SuiteDescription + reporter.ReporterConfig = config.DefaultReporterConfig +} + +func (reporter *JUnitReporter) SpecWillRun(specSummary *types.SpecSummary) { +} + +func (reporter *JUnitReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { + reporter.handleSetupSummary("BeforeSuite", setupSummary) +} + +func (reporter *JUnitReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { + reporter.handleSetupSummary("AfterSuite", setupSummary) +} + +func failureMessage(failure types.SpecFailure) string { + return fmt.Sprintf("%s\n%s\n%s", failure.ComponentCodeLocation.String(), failure.Message, failure.Location.String()) +} + +func (reporter *JUnitReporter) handleSetupSummary(name string, setupSummary *types.SetupSummary) { + if setupSummary.State != types.SpecStatePassed { + testCase := JUnitTestCase{ + Name: name, + ClassName: reporter.testSuiteName, + } + + testCase.FailureMessage = &JUnitFailureMessage{ + Type: reporter.failureTypeForState(setupSummary.State), + Message: failureMessage(setupSummary.Failure), + } + testCase.SystemOut = setupSummary.CapturedOutput + testCase.Time = setupSummary.RunTime.Seconds() + reporter.suite.TestCases = append(reporter.suite.TestCases, testCase) + } +} + +func (reporter *JUnitReporter) SpecDidComplete(specSummary *types.SpecSummary) { + testCase := JUnitTestCase{ + Name: strings.Join(specSummary.ComponentTexts[1:], " "), + ClassName: reporter.testSuiteName, + } + if reporter.ReporterConfig.ReportPassed && specSummary.State == types.SpecStatePassed { + testCase.SystemOut = specSummary.CapturedOutput + } + if specSummary.State == types.SpecStateFailed || specSummary.State == types.SpecStateTimedOut || specSummary.State == types.SpecStatePanicked { + testCase.FailureMessage = &JUnitFailureMessage{ + Type: reporter.failureTypeForState(specSummary.State), + Message: failureMessage(specSummary.Failure), + } + if specSummary.State == types.SpecStatePanicked { + testCase.FailureMessage.Message += fmt.Sprintf("\n\nPanic: %s\n\nFull stack:\n%s", + specSummary.Failure.ForwardedPanic, + specSummary.Failure.Location.FullStackTrace) + } + testCase.SystemOut = specSummary.CapturedOutput + } + if specSummary.State == types.SpecStateSkipped || specSummary.State == types.SpecStatePending { + testCase.Skipped = &JUnitSkipped{} + if specSummary.Failure.Message != "" { + testCase.Skipped.Message = failureMessage(specSummary.Failure) + } + } + testCase.Time = specSummary.RunTime.Seconds() + reporter.suite.TestCases = append(reporter.suite.TestCases, testCase) +} + +func (reporter *JUnitReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { + reporter.suite.Tests = summary.NumberOfSpecsThatWillBeRun + reporter.suite.Time = math.Trunc(summary.RunTime.Seconds()*1000) / 1000 + reporter.suite.Failures = summary.NumberOfFailedSpecs + reporter.suite.Errors = 0 + if reporter.ReporterConfig.ReportFile != "" { + reporter.filename = reporter.ReporterConfig.ReportFile + fmt.Printf("\nJUnit path was configured: %s\n", reporter.filename) + } + filePath, _ := filepath.Abs(reporter.filename) + dirPath := filepath.Dir(filePath) + err := os.MkdirAll(dirPath, os.ModePerm) + if err != nil { + fmt.Printf("\nFailed to create JUnit directory: %s\n\t%s", filePath, err.Error()) + } + file, err := os.Create(filePath) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to create JUnit report file: %s\n\t%s", filePath, err.Error()) + } + defer file.Close() + file.WriteString(xml.Header) + encoder := xml.NewEncoder(file) + encoder.Indent(" ", " ") + err = encoder.Encode(reporter.suite) + if err == nil { + fmt.Fprintf(os.Stdout, "\nJUnit report was created: %s\n", filePath) + } else { + fmt.Fprintf(os.Stderr,"\nFailed to generate JUnit report data:\n\t%s", err.Error()) + } +} + +func (reporter *JUnitReporter) failureTypeForState(state types.SpecState) string { + switch state { + case types.SpecStateFailed: + return "Failure" + case types.SpecStateTimedOut: + return "Timeout" + case types.SpecStatePanicked: + return "Panic" + default: + return "" + } +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/reporter.go b/vendor/github.com/onsi/ginkgo/reporters/reporter.go new file mode 100644 index 00000000..348b9dfc --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/reporter.go @@ -0,0 +1,15 @@ +package reporters + +import ( + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/types" +) + +type Reporter interface { + SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) + BeforeSuiteDidRun(setupSummary *types.SetupSummary) + SpecWillRun(specSummary *types.SpecSummary) + SpecDidComplete(specSummary *types.SpecSummary) + AfterSuiteDidRun(setupSummary *types.SetupSummary) + SpecSuiteDidEnd(summary *types.SuiteSummary) +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go new file mode 100644 index 00000000..45b8f886 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go @@ -0,0 +1,64 @@ +package stenographer + +import ( + "fmt" + "strings" +) + +func (s *consoleStenographer) colorize(colorCode string, format string, args ...interface{}) string { + var out string + + if len(args) > 0 { + out = fmt.Sprintf(format, args...) + } else { + out = format + } + + if s.color { + return fmt.Sprintf("%s%s%s", colorCode, out, defaultStyle) + } else { + return out + } +} + +func (s *consoleStenographer) printBanner(text string, bannerCharacter string) { + fmt.Fprintln(s.w, text) + fmt.Fprintln(s.w, strings.Repeat(bannerCharacter, len(text))) +} + +func (s *consoleStenographer) printNewLine() { + fmt.Fprintln(s.w, "") +} + +func (s *consoleStenographer) printDelimiter() { + fmt.Fprintln(s.w, s.colorize(grayColor, "%s", strings.Repeat("-", 30))) +} + +func (s *consoleStenographer) print(indentation int, format string, args ...interface{}) { + fmt.Fprint(s.w, s.indent(indentation, format, args...)) +} + +func (s *consoleStenographer) println(indentation int, format string, args ...interface{}) { + fmt.Fprintln(s.w, s.indent(indentation, format, args...)) +} + +func (s *consoleStenographer) indent(indentation int, format string, args ...interface{}) string { + var text string + + if len(args) > 0 { + text = fmt.Sprintf(format, args...) + } else { + text = format + } + + stringArray := strings.Split(text, "\n") + padding := "" + if indentation >= 0 { + padding = strings.Repeat(" ", indentation) + } + for i, s := range stringArray { + stringArray[i] = fmt.Sprintf("%s%s", padding, s) + } + + return strings.Join(stringArray, "\n") +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go new file mode 100644 index 00000000..1aa5b9db --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go @@ -0,0 +1,142 @@ +package stenographer + +import ( + "sync" + + "github.com/onsi/ginkgo/types" +) + +func NewFakeStenographerCall(method string, args ...interface{}) FakeStenographerCall { + return FakeStenographerCall{ + Method: method, + Args: args, + } +} + +type FakeStenographer struct { + calls []FakeStenographerCall + lock *sync.Mutex +} + +type FakeStenographerCall struct { + Method string + Args []interface{} +} + +func NewFakeStenographer() *FakeStenographer { + stenographer := &FakeStenographer{ + lock: &sync.Mutex{}, + } + stenographer.Reset() + return stenographer +} + +func (stenographer *FakeStenographer) Calls() []FakeStenographerCall { + stenographer.lock.Lock() + defer stenographer.lock.Unlock() + + return stenographer.calls +} + +func (stenographer *FakeStenographer) Reset() { + stenographer.lock.Lock() + defer stenographer.lock.Unlock() + + stenographer.calls = make([]FakeStenographerCall, 0) +} + +func (stenographer *FakeStenographer) CallsTo(method string) []FakeStenographerCall { + stenographer.lock.Lock() + defer stenographer.lock.Unlock() + + results := make([]FakeStenographerCall, 0) + for _, call := range stenographer.calls { + if call.Method == method { + results = append(results, call) + } + } + + return results +} + +func (stenographer *FakeStenographer) registerCall(method string, args ...interface{}) { + stenographer.lock.Lock() + defer stenographer.lock.Unlock() + + stenographer.calls = append(stenographer.calls, NewFakeStenographerCall(method, args...)) +} + +func (stenographer *FakeStenographer) AnnounceSuite(description string, randomSeed int64, randomizingAll bool, succinct bool) { + stenographer.registerCall("AnnounceSuite", description, randomSeed, randomizingAll, succinct) +} + +func (stenographer *FakeStenographer) AnnounceAggregatedParallelRun(nodes int, succinct bool) { + stenographer.registerCall("AnnounceAggregatedParallelRun", nodes, succinct) +} + +func (stenographer *FakeStenographer) AnnounceParallelRun(node int, nodes int, succinct bool) { + stenographer.registerCall("AnnounceParallelRun", node, nodes, succinct) +} + +func (stenographer *FakeStenographer) AnnounceNumberOfSpecs(specsToRun int, total int, succinct bool) { + stenographer.registerCall("AnnounceNumberOfSpecs", specsToRun, total, succinct) +} + +func (stenographer *FakeStenographer) AnnounceTotalNumberOfSpecs(total int, succinct bool) { + stenographer.registerCall("AnnounceTotalNumberOfSpecs", total, succinct) +} + +func (stenographer *FakeStenographer) AnnounceSpecRunCompletion(summary *types.SuiteSummary, succinct bool) { + stenographer.registerCall("AnnounceSpecRunCompletion", summary, succinct) +} + +func (stenographer *FakeStenographer) AnnounceSpecWillRun(spec *types.SpecSummary) { + stenographer.registerCall("AnnounceSpecWillRun", spec) +} + +func (stenographer *FakeStenographer) AnnounceBeforeSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { + stenographer.registerCall("AnnounceBeforeSuiteFailure", summary, succinct, fullTrace) +} + +func (stenographer *FakeStenographer) AnnounceAfterSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { + stenographer.registerCall("AnnounceAfterSuiteFailure", summary, succinct, fullTrace) +} +func (stenographer *FakeStenographer) AnnounceCapturedOutput(output string) { + stenographer.registerCall("AnnounceCapturedOutput", output) +} + +func (stenographer *FakeStenographer) AnnounceSuccessfulSpec(spec *types.SpecSummary) { + stenographer.registerCall("AnnounceSuccessfulSpec", spec) +} + +func (stenographer *FakeStenographer) AnnounceSuccessfulSlowSpec(spec *types.SpecSummary, succinct bool) { + stenographer.registerCall("AnnounceSuccessfulSlowSpec", spec, succinct) +} + +func (stenographer *FakeStenographer) AnnounceSuccessfulMeasurement(spec *types.SpecSummary, succinct bool) { + stenographer.registerCall("AnnounceSuccessfulMeasurement", spec, succinct) +} + +func (stenographer *FakeStenographer) AnnouncePendingSpec(spec *types.SpecSummary, noisy bool) { + stenographer.registerCall("AnnouncePendingSpec", spec, noisy) +} + +func (stenographer *FakeStenographer) AnnounceSkippedSpec(spec *types.SpecSummary, succinct bool, fullTrace bool) { + stenographer.registerCall("AnnounceSkippedSpec", spec, succinct, fullTrace) +} + +func (stenographer *FakeStenographer) AnnounceSpecTimedOut(spec *types.SpecSummary, succinct bool, fullTrace bool) { + stenographer.registerCall("AnnounceSpecTimedOut", spec, succinct, fullTrace) +} + +func (stenographer *FakeStenographer) AnnounceSpecPanicked(spec *types.SpecSummary, succinct bool, fullTrace bool) { + stenographer.registerCall("AnnounceSpecPanicked", spec, succinct, fullTrace) +} + +func (stenographer *FakeStenographer) AnnounceSpecFailed(spec *types.SpecSummary, succinct bool, fullTrace bool) { + stenographer.registerCall("AnnounceSpecFailed", spec, succinct, fullTrace) +} + +func (stenographer *FakeStenographer) SummarizeFailures(summaries []*types.SpecSummary) { + stenographer.registerCall("SummarizeFailures", summaries) +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go new file mode 100644 index 00000000..638d6fbb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go @@ -0,0 +1,572 @@ +/* +The stenographer is used by Ginkgo's reporters to generate output. + +Move along, nothing to see here. +*/ + +package stenographer + +import ( + "fmt" + "io" + "runtime" + "strings" + + "github.com/onsi/ginkgo/types" +) + +const defaultStyle = "\x1b[0m" +const boldStyle = "\x1b[1m" +const redColor = "\x1b[91m" +const greenColor = "\x1b[32m" +const yellowColor = "\x1b[33m" +const cyanColor = "\x1b[36m" +const grayColor = "\x1b[90m" +const lightGrayColor = "\x1b[37m" + +type cursorStateType int + +const ( + cursorStateTop cursorStateType = iota + cursorStateStreaming + cursorStateMidBlock + cursorStateEndBlock +) + +type Stenographer interface { + AnnounceSuite(description string, randomSeed int64, randomizingAll bool, succinct bool) + AnnounceAggregatedParallelRun(nodes int, succinct bool) + AnnounceParallelRun(node int, nodes int, succinct bool) + AnnounceTotalNumberOfSpecs(total int, succinct bool) + AnnounceNumberOfSpecs(specsToRun int, total int, succinct bool) + AnnounceSpecRunCompletion(summary *types.SuiteSummary, succinct bool) + + AnnounceSpecWillRun(spec *types.SpecSummary) + AnnounceBeforeSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) + AnnounceAfterSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) + + AnnounceCapturedOutput(output string) + + AnnounceSuccessfulSpec(spec *types.SpecSummary) + AnnounceSuccessfulSlowSpec(spec *types.SpecSummary, succinct bool) + AnnounceSuccessfulMeasurement(spec *types.SpecSummary, succinct bool) + + AnnouncePendingSpec(spec *types.SpecSummary, noisy bool) + AnnounceSkippedSpec(spec *types.SpecSummary, succinct bool, fullTrace bool) + + AnnounceSpecTimedOut(spec *types.SpecSummary, succinct bool, fullTrace bool) + AnnounceSpecPanicked(spec *types.SpecSummary, succinct bool, fullTrace bool) + AnnounceSpecFailed(spec *types.SpecSummary, succinct bool, fullTrace bool) + + SummarizeFailures(summaries []*types.SpecSummary) +} + +func New(color bool, enableFlakes bool, writer io.Writer) Stenographer { + denoter := "•" + if runtime.GOOS == "windows" { + denoter = "+" + } + return &consoleStenographer{ + color: color, + denoter: denoter, + cursorState: cursorStateTop, + enableFlakes: enableFlakes, + w: writer, + } +} + +type consoleStenographer struct { + color bool + denoter string + cursorState cursorStateType + enableFlakes bool + w io.Writer +} + +var alternatingColors = []string{defaultStyle, grayColor} + +func (s *consoleStenographer) AnnounceSuite(description string, randomSeed int64, randomizingAll bool, succinct bool) { + if succinct { + s.print(0, "[%d] %s ", randomSeed, s.colorize(boldStyle, description)) + return + } + s.printBanner(fmt.Sprintf("Running Suite: %s", description), "=") + s.print(0, "Random Seed: %s", s.colorize(boldStyle, "%d", randomSeed)) + if randomizingAll { + s.print(0, " - Will randomize all specs") + } + s.printNewLine() +} + +func (s *consoleStenographer) AnnounceParallelRun(node int, nodes int, succinct bool) { + if succinct { + s.print(0, "- node #%d ", node) + return + } + s.println(0, + "Parallel test node %s/%s.", + s.colorize(boldStyle, "%d", node), + s.colorize(boldStyle, "%d", nodes), + ) + s.printNewLine() +} + +func (s *consoleStenographer) AnnounceAggregatedParallelRun(nodes int, succinct bool) { + if succinct { + s.print(0, "- %d nodes ", nodes) + return + } + s.println(0, + "Running in parallel across %s nodes", + s.colorize(boldStyle, "%d", nodes), + ) + s.printNewLine() +} + +func (s *consoleStenographer) AnnounceNumberOfSpecs(specsToRun int, total int, succinct bool) { + if succinct { + s.print(0, "- %d/%d specs ", specsToRun, total) + s.stream() + return + } + s.println(0, + "Will run %s of %s specs", + s.colorize(boldStyle, "%d", specsToRun), + s.colorize(boldStyle, "%d", total), + ) + + s.printNewLine() +} + +func (s *consoleStenographer) AnnounceTotalNumberOfSpecs(total int, succinct bool) { + if succinct { + s.print(0, "- %d specs ", total) + s.stream() + return + } + s.println(0, + "Will run %s specs", + s.colorize(boldStyle, "%d", total), + ) + + s.printNewLine() +} + +func (s *consoleStenographer) AnnounceSpecRunCompletion(summary *types.SuiteSummary, succinct bool) { + if succinct && summary.SuiteSucceeded { + s.print(0, " %s %s ", s.colorize(greenColor, "SUCCESS!"), summary.RunTime) + return + } + s.printNewLine() + color := greenColor + if !summary.SuiteSucceeded { + color = redColor + } + s.println(0, s.colorize(boldStyle+color, "Ran %d of %d Specs in %.3f seconds", summary.NumberOfSpecsThatWillBeRun, summary.NumberOfTotalSpecs, summary.RunTime.Seconds())) + + status := "" + if summary.SuiteSucceeded { + status = s.colorize(boldStyle+greenColor, "SUCCESS!") + } else { + status = s.colorize(boldStyle+redColor, "FAIL!") + } + + flakes := "" + if s.enableFlakes { + flakes = " | " + s.colorize(yellowColor+boldStyle, "%d Flaked", summary.NumberOfFlakedSpecs) + } + + s.print(0, + "%s -- %s | %s | %s | %s\n", + status, + s.colorize(greenColor+boldStyle, "%d Passed", summary.NumberOfPassedSpecs), + s.colorize(redColor+boldStyle, "%d Failed", summary.NumberOfFailedSpecs)+flakes, + s.colorize(yellowColor+boldStyle, "%d Pending", summary.NumberOfPendingSpecs), + s.colorize(cyanColor+boldStyle, "%d Skipped", summary.NumberOfSkippedSpecs), + ) +} + +func (s *consoleStenographer) AnnounceSpecWillRun(spec *types.SpecSummary) { + s.startBlock() + for i, text := range spec.ComponentTexts[1 : len(spec.ComponentTexts)-1] { + s.print(0, s.colorize(alternatingColors[i%2], text)+" ") + } + + indentation := 0 + if len(spec.ComponentTexts) > 2 { + indentation = 1 + s.printNewLine() + } + index := len(spec.ComponentTexts) - 1 + s.print(indentation, s.colorize(boldStyle, spec.ComponentTexts[index])) + s.printNewLine() + s.print(indentation, s.colorize(lightGrayColor, spec.ComponentCodeLocations[index].String())) + s.printNewLine() + s.midBlock() +} + +func (s *consoleStenographer) AnnounceBeforeSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { + s.announceSetupFailure("BeforeSuite", summary, succinct, fullTrace) +} + +func (s *consoleStenographer) AnnounceAfterSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { + s.announceSetupFailure("AfterSuite", summary, succinct, fullTrace) +} + +func (s *consoleStenographer) announceSetupFailure(name string, summary *types.SetupSummary, succinct bool, fullTrace bool) { + s.startBlock() + var message string + switch summary.State { + case types.SpecStateFailed: + message = "Failure" + case types.SpecStatePanicked: + message = "Panic" + case types.SpecStateTimedOut: + message = "Timeout" + } + + s.println(0, s.colorize(redColor+boldStyle, "%s [%.3f seconds]", message, summary.RunTime.Seconds())) + + indentation := s.printCodeLocationBlock([]string{name}, []types.CodeLocation{summary.CodeLocation}, summary.ComponentType, 0, summary.State, true) + + s.printNewLine() + s.printFailure(indentation, summary.State, summary.Failure, fullTrace) + + s.endBlock() +} + +func (s *consoleStenographer) AnnounceCapturedOutput(output string) { + if output == "" { + return + } + + s.startBlock() + s.println(0, output) + s.midBlock() +} + +func (s *consoleStenographer) AnnounceSuccessfulSpec(spec *types.SpecSummary) { + s.print(0, s.colorize(greenColor, s.denoter)) + s.stream() +} + +func (s *consoleStenographer) AnnounceSuccessfulSlowSpec(spec *types.SpecSummary, succinct bool) { + s.printBlockWithMessage( + s.colorize(greenColor, "%s [SLOW TEST:%.3f seconds]", s.denoter, spec.RunTime.Seconds()), + "", + spec, + succinct, + ) +} + +func (s *consoleStenographer) AnnounceSuccessfulMeasurement(spec *types.SpecSummary, succinct bool) { + s.printBlockWithMessage( + s.colorize(greenColor, "%s [MEASUREMENT]", s.denoter), + s.measurementReport(spec, succinct), + spec, + succinct, + ) +} + +func (s *consoleStenographer) AnnouncePendingSpec(spec *types.SpecSummary, noisy bool) { + if noisy { + s.printBlockWithMessage( + s.colorize(yellowColor, "P [PENDING]"), + "", + spec, + false, + ) + } else { + s.print(0, s.colorize(yellowColor, "P")) + s.stream() + } +} + +func (s *consoleStenographer) AnnounceSkippedSpec(spec *types.SpecSummary, succinct bool, fullTrace bool) { + // Skips at runtime will have a non-empty spec.Failure. All others should be succinct. + if succinct || spec.Failure == (types.SpecFailure{}) { + s.print(0, s.colorize(cyanColor, "S")) + s.stream() + } else { + s.startBlock() + s.println(0, s.colorize(cyanColor+boldStyle, "S [SKIPPING]%s [%.3f seconds]", s.failureContext(spec.Failure.ComponentType), spec.RunTime.Seconds())) + + indentation := s.printCodeLocationBlock(spec.ComponentTexts, spec.ComponentCodeLocations, spec.Failure.ComponentType, spec.Failure.ComponentIndex, spec.State, succinct) + + s.printNewLine() + s.printSkip(indentation, spec.Failure) + s.endBlock() + } +} + +func (s *consoleStenographer) AnnounceSpecTimedOut(spec *types.SpecSummary, succinct bool, fullTrace bool) { + s.printSpecFailure(fmt.Sprintf("%s... Timeout", s.denoter), spec, succinct, fullTrace) +} + +func (s *consoleStenographer) AnnounceSpecPanicked(spec *types.SpecSummary, succinct bool, fullTrace bool) { + s.printSpecFailure(fmt.Sprintf("%s! Panic", s.denoter), spec, succinct, fullTrace) +} + +func (s *consoleStenographer) AnnounceSpecFailed(spec *types.SpecSummary, succinct bool, fullTrace bool) { + s.printSpecFailure(fmt.Sprintf("%s Failure", s.denoter), spec, succinct, fullTrace) +} + +func (s *consoleStenographer) SummarizeFailures(summaries []*types.SpecSummary) { + failingSpecs := []*types.SpecSummary{} + + for _, summary := range summaries { + if summary.HasFailureState() { + failingSpecs = append(failingSpecs, summary) + } + } + + if len(failingSpecs) == 0 { + return + } + + s.printNewLine() + s.printNewLine() + plural := "s" + if len(failingSpecs) == 1 { + plural = "" + } + s.println(0, s.colorize(redColor+boldStyle, "Summarizing %d Failure%s:", len(failingSpecs), plural)) + for _, summary := range failingSpecs { + s.printNewLine() + if summary.HasFailureState() { + if summary.TimedOut() { + s.print(0, s.colorize(redColor+boldStyle, "[Timeout...] ")) + } else if summary.Panicked() { + s.print(0, s.colorize(redColor+boldStyle, "[Panic!] ")) + } else if summary.Failed() { + s.print(0, s.colorize(redColor+boldStyle, "[Fail] ")) + } + s.printSpecContext(summary.ComponentTexts, summary.ComponentCodeLocations, summary.Failure.ComponentType, summary.Failure.ComponentIndex, summary.State, true) + s.printNewLine() + s.println(0, s.colorize(lightGrayColor, summary.Failure.Location.String())) + } + } +} + +func (s *consoleStenographer) startBlock() { + if s.cursorState == cursorStateStreaming { + s.printNewLine() + s.printDelimiter() + } else if s.cursorState == cursorStateMidBlock { + s.printNewLine() + } +} + +func (s *consoleStenographer) midBlock() { + s.cursorState = cursorStateMidBlock +} + +func (s *consoleStenographer) endBlock() { + s.printDelimiter() + s.cursorState = cursorStateEndBlock +} + +func (s *consoleStenographer) stream() { + s.cursorState = cursorStateStreaming +} + +func (s *consoleStenographer) printBlockWithMessage(header string, message string, spec *types.SpecSummary, succinct bool) { + s.startBlock() + s.println(0, header) + + indentation := s.printCodeLocationBlock(spec.ComponentTexts, spec.ComponentCodeLocations, types.SpecComponentTypeInvalid, 0, spec.State, succinct) + + if message != "" { + s.printNewLine() + s.println(indentation, message) + } + + s.endBlock() +} + +func (s *consoleStenographer) printSpecFailure(message string, spec *types.SpecSummary, succinct bool, fullTrace bool) { + s.startBlock() + s.println(0, s.colorize(redColor+boldStyle, "%s%s [%.3f seconds]", message, s.failureContext(spec.Failure.ComponentType), spec.RunTime.Seconds())) + + indentation := s.printCodeLocationBlock(spec.ComponentTexts, spec.ComponentCodeLocations, spec.Failure.ComponentType, spec.Failure.ComponentIndex, spec.State, succinct) + + s.printNewLine() + s.printFailure(indentation, spec.State, spec.Failure, fullTrace) + s.endBlock() +} + +func (s *consoleStenographer) failureContext(failedComponentType types.SpecComponentType) string { + switch failedComponentType { + case types.SpecComponentTypeBeforeSuite: + return " in Suite Setup (BeforeSuite)" + case types.SpecComponentTypeAfterSuite: + return " in Suite Teardown (AfterSuite)" + case types.SpecComponentTypeBeforeEach: + return " in Spec Setup (BeforeEach)" + case types.SpecComponentTypeJustBeforeEach: + return " in Spec Setup (JustBeforeEach)" + case types.SpecComponentTypeAfterEach: + return " in Spec Teardown (AfterEach)" + } + + return "" +} + +func (s *consoleStenographer) printSkip(indentation int, spec types.SpecFailure) { + s.println(indentation, s.colorize(cyanColor, spec.Message)) + s.printNewLine() + s.println(indentation, spec.Location.String()) +} + +func (s *consoleStenographer) printFailure(indentation int, state types.SpecState, failure types.SpecFailure, fullTrace bool) { + if state == types.SpecStatePanicked { + s.println(indentation, s.colorize(redColor+boldStyle, failure.Message)) + s.println(indentation, s.colorize(redColor, failure.ForwardedPanic)) + s.println(indentation, failure.Location.String()) + s.printNewLine() + s.println(indentation, s.colorize(redColor, "Full Stack Trace")) + s.println(indentation, failure.Location.FullStackTrace) + } else { + s.println(indentation, s.colorize(redColor, failure.Message)) + s.printNewLine() + s.println(indentation, failure.Location.String()) + if fullTrace { + s.printNewLine() + s.println(indentation, s.colorize(redColor, "Full Stack Trace")) + s.println(indentation, failure.Location.FullStackTrace) + } + } +} + +func (s *consoleStenographer) printSpecContext(componentTexts []string, componentCodeLocations []types.CodeLocation, failedComponentType types.SpecComponentType, failedComponentIndex int, state types.SpecState, succinct bool) int { + startIndex := 1 + indentation := 0 + + if len(componentTexts) == 1 { + startIndex = 0 + } + + for i := startIndex; i < len(componentTexts); i++ { + if (state.IsFailure() || state == types.SpecStateSkipped) && i == failedComponentIndex { + color := redColor + if state == types.SpecStateSkipped { + color = cyanColor + } + blockType := "" + switch failedComponentType { + case types.SpecComponentTypeBeforeSuite: + blockType = "BeforeSuite" + case types.SpecComponentTypeAfterSuite: + blockType = "AfterSuite" + case types.SpecComponentTypeBeforeEach: + blockType = "BeforeEach" + case types.SpecComponentTypeJustBeforeEach: + blockType = "JustBeforeEach" + case types.SpecComponentTypeAfterEach: + blockType = "AfterEach" + case types.SpecComponentTypeIt: + blockType = "It" + case types.SpecComponentTypeMeasure: + blockType = "Measurement" + } + if succinct { + s.print(0, s.colorize(color+boldStyle, "[%s] %s ", blockType, componentTexts[i])) + } else { + s.println(indentation, s.colorize(color+boldStyle, "%s [%s]", componentTexts[i], blockType)) + s.println(indentation, s.colorize(grayColor, "%s", componentCodeLocations[i])) + } + } else { + if succinct { + s.print(0, s.colorize(alternatingColors[i%2], "%s ", componentTexts[i])) + } else { + s.println(indentation, componentTexts[i]) + s.println(indentation, s.colorize(grayColor, "%s", componentCodeLocations[i])) + } + } + indentation++ + } + + return indentation +} + +func (s *consoleStenographer) printCodeLocationBlock(componentTexts []string, componentCodeLocations []types.CodeLocation, failedComponentType types.SpecComponentType, failedComponentIndex int, state types.SpecState, succinct bool) int { + indentation := s.printSpecContext(componentTexts, componentCodeLocations, failedComponentType, failedComponentIndex, state, succinct) + + if succinct { + if len(componentTexts) > 0 { + s.printNewLine() + s.print(0, s.colorize(lightGrayColor, "%s", componentCodeLocations[len(componentCodeLocations)-1])) + } + s.printNewLine() + indentation = 1 + } else { + indentation-- + } + + return indentation +} + +func (s *consoleStenographer) orderedMeasurementKeys(measurements map[string]*types.SpecMeasurement) []string { + orderedKeys := make([]string, len(measurements)) + for key, measurement := range measurements { + orderedKeys[measurement.Order] = key + } + return orderedKeys +} + +func (s *consoleStenographer) measurementReport(spec *types.SpecSummary, succinct bool) string { + if len(spec.Measurements) == 0 { + return "Found no measurements" + } + + message := []string{} + orderedKeys := s.orderedMeasurementKeys(spec.Measurements) + + if succinct { + message = append(message, fmt.Sprintf("%s samples:", s.colorize(boldStyle, "%d", spec.NumberOfSamples))) + for _, key := range orderedKeys { + measurement := spec.Measurements[key] + message = append(message, fmt.Sprintf(" %s - %s: %s%s, %s: %s%s ± %s%s, %s: %s%s", + s.colorize(boldStyle, "%s", measurement.Name), + measurement.SmallestLabel, + s.colorize(greenColor, measurement.PrecisionFmt(), measurement.Smallest), + measurement.Units, + measurement.AverageLabel, + s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.Average), + measurement.Units, + s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.StdDeviation), + measurement.Units, + measurement.LargestLabel, + s.colorize(redColor, measurement.PrecisionFmt(), measurement.Largest), + measurement.Units, + )) + } + } else { + message = append(message, fmt.Sprintf("Ran %s samples:", s.colorize(boldStyle, "%d", spec.NumberOfSamples))) + for _, key := range orderedKeys { + measurement := spec.Measurements[key] + info := "" + if measurement.Info != nil { + message = append(message, fmt.Sprintf("%v", measurement.Info)) + } + + message = append(message, fmt.Sprintf("%s:\n%s %s: %s%s\n %s: %s%s\n %s: %s%s ± %s%s", + s.colorize(boldStyle, "%s", measurement.Name), + info, + measurement.SmallestLabel, + s.colorize(greenColor, measurement.PrecisionFmt(), measurement.Smallest), + measurement.Units, + measurement.LargestLabel, + s.colorize(redColor, measurement.PrecisionFmt(), measurement.Largest), + measurement.Units, + measurement.AverageLabel, + s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.Average), + measurement.Units, + s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.StdDeviation), + measurement.Units, + )) + } + } + + return strings.Join(message, "\n") +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/LICENSE b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/LICENSE new file mode 100644 index 00000000..91b5cef3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Yasuhiro Matsumoto + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md new file mode 100644 index 00000000..e84226a7 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md @@ -0,0 +1,43 @@ +# go-colorable + +Colorable writer for windows. + +For example, most of logger packages doesn't show colors on windows. (I know we can do it with ansicon. But I don't want.) +This package is possible to handle escape sequence for ansi color on windows. + +## Too Bad! + +![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/bad.png) + + +## So Good! + +![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/good.png) + +## Usage + +```go +logrus.SetFormatter(&logrus.TextFormatter{ForceColors: true}) +logrus.SetOutput(colorable.NewColorableStdout()) + +logrus.Info("succeeded") +logrus.Warn("not correct") +logrus.Error("something error") +logrus.Fatal("panic") +``` + +You can compile above code on non-windows OSs. + +## Installation + +``` +$ go get github.com/mattn/go-colorable +``` + +# License + +MIT + +# Author + +Yasuhiro Matsumoto (a.k.a mattn) diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go new file mode 100644 index 00000000..52d6653b --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go @@ -0,0 +1,24 @@ +// +build !windows + +package colorable + +import ( + "io" + "os" +) + +func NewColorable(file *os.File) io.Writer { + if file == nil { + panic("nil passed instead of *os.File to NewColorable()") + } + + return file +} + +func NewColorableStdout() io.Writer { + return os.Stdout +} + +func NewColorableStderr() io.Writer { + return os.Stderr +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_windows.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_windows.go new file mode 100644 index 00000000..10880092 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_windows.go @@ -0,0 +1,783 @@ +package colorable + +import ( + "bytes" + "fmt" + "io" + "math" + "os" + "strconv" + "strings" + "syscall" + "unsafe" + + "github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty" +) + +const ( + foregroundBlue = 0x1 + foregroundGreen = 0x2 + foregroundRed = 0x4 + foregroundIntensity = 0x8 + foregroundMask = (foregroundRed | foregroundBlue | foregroundGreen | foregroundIntensity) + backgroundBlue = 0x10 + backgroundGreen = 0x20 + backgroundRed = 0x40 + backgroundIntensity = 0x80 + backgroundMask = (backgroundRed | backgroundBlue | backgroundGreen | backgroundIntensity) +) + +type wchar uint16 +type short int16 +type dword uint32 +type word uint16 + +type coord struct { + x short + y short +} + +type smallRect struct { + left short + top short + right short + bottom short +} + +type consoleScreenBufferInfo struct { + size coord + cursorPosition coord + attributes word + window smallRect + maximumWindowSize coord +} + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") + procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute") + procSetConsoleCursorPosition = kernel32.NewProc("SetConsoleCursorPosition") + procFillConsoleOutputCharacter = kernel32.NewProc("FillConsoleOutputCharacterW") + procFillConsoleOutputAttribute = kernel32.NewProc("FillConsoleOutputAttribute") +) + +type Writer struct { + out io.Writer + handle syscall.Handle + lastbuf bytes.Buffer + oldattr word +} + +func NewColorable(file *os.File) io.Writer { + if file == nil { + panic("nil passed instead of *os.File to NewColorable()") + } + + if isatty.IsTerminal(file.Fd()) { + var csbi consoleScreenBufferInfo + handle := syscall.Handle(file.Fd()) + procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi))) + return &Writer{out: file, handle: handle, oldattr: csbi.attributes} + } else { + return file + } +} + +func NewColorableStdout() io.Writer { + return NewColorable(os.Stdout) +} + +func NewColorableStderr() io.Writer { + return NewColorable(os.Stderr) +} + +var color256 = map[int]int{ + 0: 0x000000, + 1: 0x800000, + 2: 0x008000, + 3: 0x808000, + 4: 0x000080, + 5: 0x800080, + 6: 0x008080, + 7: 0xc0c0c0, + 8: 0x808080, + 9: 0xff0000, + 10: 0x00ff00, + 11: 0xffff00, + 12: 0x0000ff, + 13: 0xff00ff, + 14: 0x00ffff, + 15: 0xffffff, + 16: 0x000000, + 17: 0x00005f, + 18: 0x000087, + 19: 0x0000af, + 20: 0x0000d7, + 21: 0x0000ff, + 22: 0x005f00, + 23: 0x005f5f, + 24: 0x005f87, + 25: 0x005faf, + 26: 0x005fd7, + 27: 0x005fff, + 28: 0x008700, + 29: 0x00875f, + 30: 0x008787, + 31: 0x0087af, + 32: 0x0087d7, + 33: 0x0087ff, + 34: 0x00af00, + 35: 0x00af5f, + 36: 0x00af87, + 37: 0x00afaf, + 38: 0x00afd7, + 39: 0x00afff, + 40: 0x00d700, + 41: 0x00d75f, + 42: 0x00d787, + 43: 0x00d7af, + 44: 0x00d7d7, + 45: 0x00d7ff, + 46: 0x00ff00, + 47: 0x00ff5f, + 48: 0x00ff87, + 49: 0x00ffaf, + 50: 0x00ffd7, + 51: 0x00ffff, + 52: 0x5f0000, + 53: 0x5f005f, + 54: 0x5f0087, + 55: 0x5f00af, + 56: 0x5f00d7, + 57: 0x5f00ff, + 58: 0x5f5f00, + 59: 0x5f5f5f, + 60: 0x5f5f87, + 61: 0x5f5faf, + 62: 0x5f5fd7, + 63: 0x5f5fff, + 64: 0x5f8700, + 65: 0x5f875f, + 66: 0x5f8787, + 67: 0x5f87af, + 68: 0x5f87d7, + 69: 0x5f87ff, + 70: 0x5faf00, + 71: 0x5faf5f, + 72: 0x5faf87, + 73: 0x5fafaf, + 74: 0x5fafd7, + 75: 0x5fafff, + 76: 0x5fd700, + 77: 0x5fd75f, + 78: 0x5fd787, + 79: 0x5fd7af, + 80: 0x5fd7d7, + 81: 0x5fd7ff, + 82: 0x5fff00, + 83: 0x5fff5f, + 84: 0x5fff87, + 85: 0x5fffaf, + 86: 0x5fffd7, + 87: 0x5fffff, + 88: 0x870000, + 89: 0x87005f, + 90: 0x870087, + 91: 0x8700af, + 92: 0x8700d7, + 93: 0x8700ff, + 94: 0x875f00, + 95: 0x875f5f, + 96: 0x875f87, + 97: 0x875faf, + 98: 0x875fd7, + 99: 0x875fff, + 100: 0x878700, + 101: 0x87875f, + 102: 0x878787, + 103: 0x8787af, + 104: 0x8787d7, + 105: 0x8787ff, + 106: 0x87af00, + 107: 0x87af5f, + 108: 0x87af87, + 109: 0x87afaf, + 110: 0x87afd7, + 111: 0x87afff, + 112: 0x87d700, + 113: 0x87d75f, + 114: 0x87d787, + 115: 0x87d7af, + 116: 0x87d7d7, + 117: 0x87d7ff, + 118: 0x87ff00, + 119: 0x87ff5f, + 120: 0x87ff87, + 121: 0x87ffaf, + 122: 0x87ffd7, + 123: 0x87ffff, + 124: 0xaf0000, + 125: 0xaf005f, + 126: 0xaf0087, + 127: 0xaf00af, + 128: 0xaf00d7, + 129: 0xaf00ff, + 130: 0xaf5f00, + 131: 0xaf5f5f, + 132: 0xaf5f87, + 133: 0xaf5faf, + 134: 0xaf5fd7, + 135: 0xaf5fff, + 136: 0xaf8700, + 137: 0xaf875f, + 138: 0xaf8787, + 139: 0xaf87af, + 140: 0xaf87d7, + 141: 0xaf87ff, + 142: 0xafaf00, + 143: 0xafaf5f, + 144: 0xafaf87, + 145: 0xafafaf, + 146: 0xafafd7, + 147: 0xafafff, + 148: 0xafd700, + 149: 0xafd75f, + 150: 0xafd787, + 151: 0xafd7af, + 152: 0xafd7d7, + 153: 0xafd7ff, + 154: 0xafff00, + 155: 0xafff5f, + 156: 0xafff87, + 157: 0xafffaf, + 158: 0xafffd7, + 159: 0xafffff, + 160: 0xd70000, + 161: 0xd7005f, + 162: 0xd70087, + 163: 0xd700af, + 164: 0xd700d7, + 165: 0xd700ff, + 166: 0xd75f00, + 167: 0xd75f5f, + 168: 0xd75f87, + 169: 0xd75faf, + 170: 0xd75fd7, + 171: 0xd75fff, + 172: 0xd78700, + 173: 0xd7875f, + 174: 0xd78787, + 175: 0xd787af, + 176: 0xd787d7, + 177: 0xd787ff, + 178: 0xd7af00, + 179: 0xd7af5f, + 180: 0xd7af87, + 181: 0xd7afaf, + 182: 0xd7afd7, + 183: 0xd7afff, + 184: 0xd7d700, + 185: 0xd7d75f, + 186: 0xd7d787, + 187: 0xd7d7af, + 188: 0xd7d7d7, + 189: 0xd7d7ff, + 190: 0xd7ff00, + 191: 0xd7ff5f, + 192: 0xd7ff87, + 193: 0xd7ffaf, + 194: 0xd7ffd7, + 195: 0xd7ffff, + 196: 0xff0000, + 197: 0xff005f, + 198: 0xff0087, + 199: 0xff00af, + 200: 0xff00d7, + 201: 0xff00ff, + 202: 0xff5f00, + 203: 0xff5f5f, + 204: 0xff5f87, + 205: 0xff5faf, + 206: 0xff5fd7, + 207: 0xff5fff, + 208: 0xff8700, + 209: 0xff875f, + 210: 0xff8787, + 211: 0xff87af, + 212: 0xff87d7, + 213: 0xff87ff, + 214: 0xffaf00, + 215: 0xffaf5f, + 216: 0xffaf87, + 217: 0xffafaf, + 218: 0xffafd7, + 219: 0xffafff, + 220: 0xffd700, + 221: 0xffd75f, + 222: 0xffd787, + 223: 0xffd7af, + 224: 0xffd7d7, + 225: 0xffd7ff, + 226: 0xffff00, + 227: 0xffff5f, + 228: 0xffff87, + 229: 0xffffaf, + 230: 0xffffd7, + 231: 0xffffff, + 232: 0x080808, + 233: 0x121212, + 234: 0x1c1c1c, + 235: 0x262626, + 236: 0x303030, + 237: 0x3a3a3a, + 238: 0x444444, + 239: 0x4e4e4e, + 240: 0x585858, + 241: 0x626262, + 242: 0x6c6c6c, + 243: 0x767676, + 244: 0x808080, + 245: 0x8a8a8a, + 246: 0x949494, + 247: 0x9e9e9e, + 248: 0xa8a8a8, + 249: 0xb2b2b2, + 250: 0xbcbcbc, + 251: 0xc6c6c6, + 252: 0xd0d0d0, + 253: 0xdadada, + 254: 0xe4e4e4, + 255: 0xeeeeee, +} + +func (w *Writer) Write(data []byte) (n int, err error) { + var csbi consoleScreenBufferInfo + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + + er := bytes.NewBuffer(data) +loop: + for { + r1, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + if r1 == 0 { + break loop + } + + c1, _, err := er.ReadRune() + if err != nil { + break loop + } + if c1 != 0x1b { + fmt.Fprint(w.out, string(c1)) + continue + } + c2, _, err := er.ReadRune() + if err != nil { + w.lastbuf.WriteRune(c1) + break loop + } + if c2 != 0x5b { + w.lastbuf.WriteRune(c1) + w.lastbuf.WriteRune(c2) + continue + } + + var buf bytes.Buffer + var m rune + for { + c, _, err := er.ReadRune() + if err != nil { + w.lastbuf.WriteRune(c1) + w.lastbuf.WriteRune(c2) + w.lastbuf.Write(buf.Bytes()) + break loop + } + if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' { + m = c + break + } + buf.Write([]byte(string(c))) + } + + var csbi consoleScreenBufferInfo + switch m { + case 'A': + n, err = strconv.Atoi(buf.String()) + if err != nil { + continue + } + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + csbi.cursorPosition.y -= short(n) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + case 'B': + n, err = strconv.Atoi(buf.String()) + if err != nil { + continue + } + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + csbi.cursorPosition.y += short(n) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + case 'C': + n, err = strconv.Atoi(buf.String()) + if err != nil { + continue + } + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + csbi.cursorPosition.x -= short(n) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + case 'D': + n, err = strconv.Atoi(buf.String()) + if err != nil { + continue + } + if n, err = strconv.Atoi(buf.String()); err == nil { + var csbi consoleScreenBufferInfo + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + csbi.cursorPosition.x += short(n) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + } + case 'E': + n, err = strconv.Atoi(buf.String()) + if err != nil { + continue + } + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + csbi.cursorPosition.x = 0 + csbi.cursorPosition.y += short(n) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + case 'F': + n, err = strconv.Atoi(buf.String()) + if err != nil { + continue + } + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + csbi.cursorPosition.x = 0 + csbi.cursorPosition.y -= short(n) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + case 'G': + n, err = strconv.Atoi(buf.String()) + if err != nil { + continue + } + procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) + csbi.cursorPosition.x = short(n) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + case 'H': + token := strings.Split(buf.String(), ";") + if len(token) != 2 { + continue + } + n1, err := strconv.Atoi(token[0]) + if err != nil { + continue + } + n2, err := strconv.Atoi(token[1]) + if err != nil { + continue + } + csbi.cursorPosition.x = short(n2) + csbi.cursorPosition.x = short(n1) + procSetConsoleCursorPosition.Call(uintptr(w.handle), *(*uintptr)(unsafe.Pointer(&csbi.cursorPosition))) + case 'J': + n, err := strconv.Atoi(buf.String()) + if err != nil { + continue + } + var cursor coord + switch n { + case 0: + cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y} + case 1: + cursor = coord{x: csbi.window.left, y: csbi.window.top} + case 2: + cursor = coord{x: csbi.window.left, y: csbi.window.top} + } + var count, written dword + count = dword(csbi.size.x - csbi.cursorPosition.x + (csbi.size.y-csbi.cursorPosition.y)*csbi.size.x) + procFillConsoleOutputCharacter.Call(uintptr(w.handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written))) + procFillConsoleOutputAttribute.Call(uintptr(w.handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written))) + case 'K': + n, err := strconv.Atoi(buf.String()) + if err != nil { + continue + } + var cursor coord + switch n { + case 0: + cursor = coord{x: csbi.cursorPosition.x, y: csbi.cursorPosition.y} + case 1: + cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y} + case 2: + cursor = coord{x: csbi.window.left, y: csbi.window.top + csbi.cursorPosition.y} + } + var count, written dword + count = dword(csbi.size.x - csbi.cursorPosition.x) + procFillConsoleOutputCharacter.Call(uintptr(w.handle), uintptr(' '), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written))) + procFillConsoleOutputAttribute.Call(uintptr(w.handle), uintptr(csbi.attributes), uintptr(count), *(*uintptr)(unsafe.Pointer(&cursor)), uintptr(unsafe.Pointer(&written))) + case 'm': + attr := csbi.attributes + cs := buf.String() + if cs == "" { + procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(w.oldattr)) + continue + } + token := strings.Split(cs, ";") + for i := 0; i < len(token); i += 1 { + ns := token[i] + if n, err = strconv.Atoi(ns); err == nil { + switch { + case n == 0 || n == 100: + attr = w.oldattr + case 1 <= n && n <= 5: + attr |= foregroundIntensity + case n == 7: + attr = ((attr & foregroundMask) << 4) | ((attr & backgroundMask) >> 4) + case 22 == n || n == 25 || n == 25: + attr |= foregroundIntensity + case n == 27: + attr = ((attr & foregroundMask) << 4) | ((attr & backgroundMask) >> 4) + case 30 <= n && n <= 37: + attr = (attr & backgroundMask) + if (n-30)&1 != 0 { + attr |= foregroundRed + } + if (n-30)&2 != 0 { + attr |= foregroundGreen + } + if (n-30)&4 != 0 { + attr |= foregroundBlue + } + case n == 38: // set foreground color. + if i < len(token)-2 && (token[i+1] == "5" || token[i+1] == "05") { + if n256, err := strconv.Atoi(token[i+2]); err == nil { + if n256foreAttr == nil { + n256setup() + } + attr &= backgroundMask + attr |= n256foreAttr[n256] + i += 2 + } + } else { + attr = attr & (w.oldattr & backgroundMask) + } + case n == 39: // reset foreground color. + attr &= backgroundMask + attr |= w.oldattr & foregroundMask + case 40 <= n && n <= 47: + attr = (attr & foregroundMask) + if (n-40)&1 != 0 { + attr |= backgroundRed + } + if (n-40)&2 != 0 { + attr |= backgroundGreen + } + if (n-40)&4 != 0 { + attr |= backgroundBlue + } + case n == 48: // set background color. + if i < len(token)-2 && token[i+1] == "5" { + if n256, err := strconv.Atoi(token[i+2]); err == nil { + if n256backAttr == nil { + n256setup() + } + attr &= foregroundMask + attr |= n256backAttr[n256] + i += 2 + } + } else { + attr = attr & (w.oldattr & foregroundMask) + } + case n == 49: // reset foreground color. + attr &= foregroundMask + attr |= w.oldattr & backgroundMask + case 90 <= n && n <= 97: + attr = (attr & backgroundMask) + attr |= foregroundIntensity + if (n-90)&1 != 0 { + attr |= foregroundRed + } + if (n-90)&2 != 0 { + attr |= foregroundGreen + } + if (n-90)&4 != 0 { + attr |= foregroundBlue + } + case 100 <= n && n <= 107: + attr = (attr & foregroundMask) + attr |= backgroundIntensity + if (n-100)&1 != 0 { + attr |= backgroundRed + } + if (n-100)&2 != 0 { + attr |= backgroundGreen + } + if (n-100)&4 != 0 { + attr |= backgroundBlue + } + } + procSetConsoleTextAttribute.Call(uintptr(w.handle), uintptr(attr)) + } + } + } + } + return len(data) - w.lastbuf.Len(), nil +} + +type consoleColor struct { + rgb int + red bool + green bool + blue bool + intensity bool +} + +func (c consoleColor) foregroundAttr() (attr word) { + if c.red { + attr |= foregroundRed + } + if c.green { + attr |= foregroundGreen + } + if c.blue { + attr |= foregroundBlue + } + if c.intensity { + attr |= foregroundIntensity + } + return +} + +func (c consoleColor) backgroundAttr() (attr word) { + if c.red { + attr |= backgroundRed + } + if c.green { + attr |= backgroundGreen + } + if c.blue { + attr |= backgroundBlue + } + if c.intensity { + attr |= backgroundIntensity + } + return +} + +var color16 = []consoleColor{ + consoleColor{0x000000, false, false, false, false}, + consoleColor{0x000080, false, false, true, false}, + consoleColor{0x008000, false, true, false, false}, + consoleColor{0x008080, false, true, true, false}, + consoleColor{0x800000, true, false, false, false}, + consoleColor{0x800080, true, false, true, false}, + consoleColor{0x808000, true, true, false, false}, + consoleColor{0xc0c0c0, true, true, true, false}, + consoleColor{0x808080, false, false, false, true}, + consoleColor{0x0000ff, false, false, true, true}, + consoleColor{0x00ff00, false, true, false, true}, + consoleColor{0x00ffff, false, true, true, true}, + consoleColor{0xff0000, true, false, false, true}, + consoleColor{0xff00ff, true, false, true, true}, + consoleColor{0xffff00, true, true, false, true}, + consoleColor{0xffffff, true, true, true, true}, +} + +type hsv struct { + h, s, v float32 +} + +func (a hsv) dist(b hsv) float32 { + dh := a.h - b.h + switch { + case dh > 0.5: + dh = 1 - dh + case dh < -0.5: + dh = -1 - dh + } + ds := a.s - b.s + dv := a.v - b.v + return float32(math.Sqrt(float64(dh*dh + ds*ds + dv*dv))) +} + +func toHSV(rgb int) hsv { + r, g, b := float32((rgb&0xFF0000)>>16)/256.0, + float32((rgb&0x00FF00)>>8)/256.0, + float32(rgb&0x0000FF)/256.0 + min, max := minmax3f(r, g, b) + h := max - min + if h > 0 { + if max == r { + h = (g - b) / h + if h < 0 { + h += 6 + } + } else if max == g { + h = 2 + (b-r)/h + } else { + h = 4 + (r-g)/h + } + } + h /= 6.0 + s := max - min + if max != 0 { + s /= max + } + v := max + return hsv{h: h, s: s, v: v} +} + +type hsvTable []hsv + +func toHSVTable(rgbTable []consoleColor) hsvTable { + t := make(hsvTable, len(rgbTable)) + for i, c := range rgbTable { + t[i] = toHSV(c.rgb) + } + return t +} + +func (t hsvTable) find(rgb int) consoleColor { + hsv := toHSV(rgb) + n := 7 + l := float32(5.0) + for i, p := range t { + d := hsv.dist(p) + if d < l { + l, n = d, i + } + } + return color16[n] +} + +func minmax3f(a, b, c float32) (min, max float32) { + if a < b { + if b < c { + return a, c + } else if a < c { + return a, b + } else { + return c, b + } + } else { + if a < c { + return b, c + } else if b < c { + return b, a + } else { + return c, a + } + } +} + +var n256foreAttr []word +var n256backAttr []word + +func n256setup() { + n256foreAttr = make([]word, 256) + n256backAttr = make([]word, 256) + t := toHSVTable(color16) + for i, rgb := range color256 { + c := t.find(rgb) + n256foreAttr[i] = c.foregroundAttr() + n256backAttr[i] = c.backgroundAttr() + } +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go new file mode 100644 index 00000000..fb976dbd --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go @@ -0,0 +1,57 @@ +package colorable + +import ( + "bytes" + "fmt" + "io" +) + +type NonColorable struct { + out io.Writer + lastbuf bytes.Buffer +} + +func NewNonColorable(w io.Writer) io.Writer { + return &NonColorable{out: w} +} + +func (w *NonColorable) Write(data []byte) (n int, err error) { + er := bytes.NewBuffer(data) +loop: + for { + c1, _, err := er.ReadRune() + if err != nil { + break loop + } + if c1 != 0x1b { + fmt.Fprint(w.out, string(c1)) + continue + } + c2, _, err := er.ReadRune() + if err != nil { + w.lastbuf.WriteRune(c1) + break loop + } + if c2 != 0x5b { + w.lastbuf.WriteRune(c1) + w.lastbuf.WriteRune(c2) + continue + } + + var buf bytes.Buffer + for { + c, _, err := er.ReadRune() + if err != nil { + w.lastbuf.WriteRune(c1) + w.lastbuf.WriteRune(c2) + w.lastbuf.Write(buf.Bytes()) + break loop + } + if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' { + break + } + buf.Write([]byte(string(c))) + } + } + return len(data) - w.lastbuf.Len(), nil +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE new file mode 100644 index 00000000..65dc692b --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE @@ -0,0 +1,9 @@ +Copyright (c) Yasuhiro MATSUMOTO + +MIT License (Expat) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md new file mode 100644 index 00000000..74845de4 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md @@ -0,0 +1,37 @@ +# go-isatty + +isatty for golang + +## Usage + +```go +package main + +import ( + "fmt" + "github.com/mattn/go-isatty" + "os" +) + +func main() { + if isatty.IsTerminal(os.Stdout.Fd()) { + fmt.Println("Is Terminal") + } else { + fmt.Println("Is Not Terminal") + } +} +``` + +## Installation + +``` +$ go get github.com/mattn/go-isatty +``` + +# License + +MIT + +# Author + +Yasuhiro Matsumoto (a.k.a mattn) diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go new file mode 100644 index 00000000..17d4f90e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go @@ -0,0 +1,2 @@ +// Package isatty implements interface to isatty +package isatty diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go new file mode 100644 index 00000000..83c58877 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go @@ -0,0 +1,9 @@ +// +build appengine + +package isatty + +// IsTerminal returns true if the file descriptor is terminal which +// is always false on on appengine classic which is a sandboxed PaaS. +func IsTerminal(fd uintptr) bool { + return false +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go new file mode 100644 index 00000000..98ffe86a --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go @@ -0,0 +1,18 @@ +// +build darwin freebsd openbsd netbsd +// +build !appengine + +package isatty + +import ( + "syscall" + "unsafe" +) + +const ioctlReadTermios = syscall.TIOCGETA + +// IsTerminal return true if the file descriptor is terminal. +func IsTerminal(fd uintptr) bool { + var termios syscall.Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go new file mode 100644 index 00000000..9d24bac1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go @@ -0,0 +1,18 @@ +// +build linux +// +build !appengine + +package isatty + +import ( + "syscall" + "unsafe" +) + +const ioctlReadTermios = syscall.TCGETS + +// IsTerminal return true if the file descriptor is terminal. +func IsTerminal(fd uintptr) bool { + var termios syscall.Termios + _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) + return err == 0 +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go new file mode 100644 index 00000000..1f0c6bf5 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go @@ -0,0 +1,16 @@ +// +build solaris +// +build !appengine + +package isatty + +import ( + "golang.org/x/sys/unix" +) + +// IsTerminal returns true if the given file descriptor is a terminal. +// see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c +func IsTerminal(fd uintptr) bool { + var termio unix.Termio + err := unix.IoctlSetTermio(int(fd), unix.TCGETA, &termio) + return err == nil +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go new file mode 100644 index 00000000..83c398b1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go @@ -0,0 +1,19 @@ +// +build windows +// +build !appengine + +package isatty + +import ( + "syscall" + "unsafe" +) + +var kernel32 = syscall.NewLazyDLL("kernel32.dll") +var procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + +// IsTerminal return true if the file descriptor is terminal. +func IsTerminal(fd uintptr) bool { + var st uint32 + r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, fd, uintptr(unsafe.Pointer(&st)), 0) + return r != 0 && e == 0 +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go new file mode 100644 index 00000000..84fd8aff --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go @@ -0,0 +1,106 @@ +/* + +TeamCity Reporter for Ginkgo + +Makes use of TeamCity's support for Service Messages +http://confluence.jetbrains.com/display/TCD7/Build+Script+Interaction+with+TeamCity#BuildScriptInteractionwithTeamCity-ReportingTests +*/ + +package reporters + +import ( + "fmt" + "io" + "strings" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/types" +) + +const ( + messageId = "##teamcity" +) + +type TeamCityReporter struct { + writer io.Writer + testSuiteName string + ReporterConfig config.DefaultReporterConfigType +} + +func NewTeamCityReporter(writer io.Writer) *TeamCityReporter { + return &TeamCityReporter{ + writer: writer, + } +} + +func (reporter *TeamCityReporter) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { + reporter.testSuiteName = escape(summary.SuiteDescription) + fmt.Fprintf(reporter.writer, "%s[testSuiteStarted name='%s']\n", messageId, reporter.testSuiteName) +} + +func (reporter *TeamCityReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { + reporter.handleSetupSummary("BeforeSuite", setupSummary) +} + +func (reporter *TeamCityReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { + reporter.handleSetupSummary("AfterSuite", setupSummary) +} + +func (reporter *TeamCityReporter) handleSetupSummary(name string, setupSummary *types.SetupSummary) { + if setupSummary.State != types.SpecStatePassed { + testName := escape(name) + fmt.Fprintf(reporter.writer, "%s[testStarted name='%s']\n", messageId, testName) + message := reporter.failureMessage(setupSummary.Failure) + details := reporter.failureDetails(setupSummary.Failure) + fmt.Fprintf(reporter.writer, "%s[testFailed name='%s' message='%s' details='%s']\n", messageId, testName, message, details) + durationInMilliseconds := setupSummary.RunTime.Seconds() * 1000 + fmt.Fprintf(reporter.writer, "%s[testFinished name='%s' duration='%v']\n", messageId, testName, durationInMilliseconds) + } +} + +func (reporter *TeamCityReporter) SpecWillRun(specSummary *types.SpecSummary) { + testName := escape(strings.Join(specSummary.ComponentTexts[1:], " ")) + fmt.Fprintf(reporter.writer, "%s[testStarted name='%s']\n", messageId, testName) +} + +func (reporter *TeamCityReporter) SpecDidComplete(specSummary *types.SpecSummary) { + testName := escape(strings.Join(specSummary.ComponentTexts[1:], " ")) + + if reporter.ReporterConfig.ReportPassed && specSummary.State == types.SpecStatePassed { + details := escape(specSummary.CapturedOutput) + fmt.Fprintf(reporter.writer, "%s[testPassed name='%s' details='%s']\n", messageId, testName, details) + } + if specSummary.State == types.SpecStateFailed || specSummary.State == types.SpecStateTimedOut || specSummary.State == types.SpecStatePanicked { + message := reporter.failureMessage(specSummary.Failure) + details := reporter.failureDetails(specSummary.Failure) + fmt.Fprintf(reporter.writer, "%s[testFailed name='%s' message='%s' details='%s']\n", messageId, testName, message, details) + } + if specSummary.State == types.SpecStateSkipped || specSummary.State == types.SpecStatePending { + fmt.Fprintf(reporter.writer, "%s[testIgnored name='%s']\n", messageId, testName) + } + + durationInMilliseconds := specSummary.RunTime.Seconds() * 1000 + fmt.Fprintf(reporter.writer, "%s[testFinished name='%s' duration='%v']\n", messageId, testName, durationInMilliseconds) +} + +func (reporter *TeamCityReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { + fmt.Fprintf(reporter.writer, "%s[testSuiteFinished name='%s']\n", messageId, reporter.testSuiteName) +} + +func (reporter *TeamCityReporter) failureMessage(failure types.SpecFailure) string { + return escape(failure.ComponentCodeLocation.String()) +} + +func (reporter *TeamCityReporter) failureDetails(failure types.SpecFailure) string { + return escape(fmt.Sprintf("%s\n%s", failure.Message, failure.Location.String())) +} + +func escape(output string) string { + output = strings.Replace(output, "|", "||", -1) + output = strings.Replace(output, "'", "|'", -1) + output = strings.Replace(output, "\n", "|n", -1) + output = strings.Replace(output, "\r", "|r", -1) + output = strings.Replace(output, "[", "|[", -1) + output = strings.Replace(output, "]", "|]", -1) + return output +} diff --git a/vendor/github.com/onsi/ginkgo/types/code_location.go b/vendor/github.com/onsi/ginkgo/types/code_location.go new file mode 100644 index 00000000..935a89e1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/types/code_location.go @@ -0,0 +1,15 @@ +package types + +import ( + "fmt" +) + +type CodeLocation struct { + FileName string + LineNumber int + FullStackTrace string +} + +func (codeLocation CodeLocation) String() string { + return fmt.Sprintf("%s:%d", codeLocation.FileName, codeLocation.LineNumber) +} diff --git a/vendor/github.com/onsi/ginkgo/types/deprecation_support.go b/vendor/github.com/onsi/ginkgo/types/deprecation_support.go new file mode 100644 index 00000000..305c134b --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/types/deprecation_support.go @@ -0,0 +1,150 @@ +package types + +import ( + "os" + "strconv" + "strings" + "unicode" + + "github.com/onsi/ginkgo/config" + "github.com/onsi/ginkgo/formatter" +) + +type Deprecation struct { + Message string + DocLink string + Version string +} + +type deprecations struct{} + +var Deprecations = deprecations{} + +func (d deprecations) CustomReporter() Deprecation { + return Deprecation{ + Message: "You are using a custom reporter. Support for custom reporters will likely be removed in V2. Most users were using them to generate junit or teamcity reports and this functionality will be merged into the core reporter. In addition, Ginkgo 2.0 will support emitting a JSON-formatted report that users can then manipulate to generate custom reports.\n\n{{red}}{{bold}}If this change will be impactful to you please leave a comment on {{cyan}}{{underline}}https://github.com/onsi/ginkgo/issues/711{{/}}", + DocLink: "removed-custom-reporters", + Version: "1.16.0", + } +} + +func (d deprecations) V1Reporter() Deprecation { + return Deprecation{ + Message: "You are using a V1 Ginkgo Reporter. Please update your custom reporter to the new V2 Reporter interface.", + DocLink: "changed-reporter-interface", + Version: "1.16.0", + } +} + +func (d deprecations) Async() Deprecation { + return Deprecation{ + Message: "You are passing a Done channel to a test node to test asynchronous behavior. This is deprecated in Ginkgo V2. Your test will run synchronously and the timeout will be ignored.", + DocLink: "removed-async-testing", + Version: "1.16.0", + } +} + +func (d deprecations) Measure() Deprecation { + return Deprecation{ + Message: "Measure is deprecated and will be removed in Ginkgo V2. Please migrate to gomega/gmeasure.", + DocLink: "removed-measure", + Version: "1.16.3", + } +} + +func (d deprecations) Convert() Deprecation { + return Deprecation{ + Message: "The convert command is deprecated in Ginkgo V2", + DocLink: "removed-ginkgo-convert", + Version: "1.16.0", + } +} + +func (d deprecations) Blur() Deprecation { + return Deprecation{ + Message: "The blur command is deprecated in Ginkgo V2. Use 'ginkgo unfocus' instead.", + Version: "1.16.0", + } +} + +type DeprecationTracker struct { + deprecations map[Deprecation][]CodeLocation +} + +func NewDeprecationTracker() *DeprecationTracker { + return &DeprecationTracker{ + deprecations: map[Deprecation][]CodeLocation{}, + } +} + +func (d *DeprecationTracker) TrackDeprecation(deprecation Deprecation, cl ...CodeLocation) { + ackVersion := os.Getenv("ACK_GINKGO_DEPRECATIONS") + if deprecation.Version != "" && ackVersion != "" { + ack := ParseSemVer(ackVersion) + version := ParseSemVer(deprecation.Version) + if ack.GreaterThanOrEqualTo(version) { + return + } + } + + if len(cl) == 1 { + d.deprecations[deprecation] = append(d.deprecations[deprecation], cl[0]) + } else { + d.deprecations[deprecation] = []CodeLocation{} + } +} + +func (d *DeprecationTracker) DidTrackDeprecations() bool { + return len(d.deprecations) > 0 +} + +func (d *DeprecationTracker) DeprecationsReport() string { + out := formatter.F("{{light-yellow}}You're using deprecated Ginkgo functionality:{{/}}\n") + out += formatter.F("{{light-yellow}}============================================={{/}}\n") + out += formatter.F("Ginkgo 2.0 is under active development and will introduce (a small number of) breaking changes.\n") + out += formatter.F("To learn more, view the migration guide at {{cyan}}{{underline}}https://github.com/onsi/ginkgo/blob/v2/docs/MIGRATING_TO_V2.md{{/}}\n") + out += formatter.F("To comment, chime in at {{cyan}}{{underline}}https://github.com/onsi/ginkgo/issues/711{{/}}\n\n") + + for deprecation, locations := range d.deprecations { + out += formatter.Fi(1, "{{yellow}}"+deprecation.Message+"{{/}}\n") + if deprecation.DocLink != "" { + out += formatter.Fi(1, "{{bold}}Learn more at:{{/}} {{cyan}}{{underline}}https://github.com/onsi/ginkgo/blob/v2/docs/MIGRATING_TO_V2.md#%s{{/}}\n", deprecation.DocLink) + } + for _, location := range locations { + out += formatter.Fi(2, "{{gray}}%s{{/}}\n", location) + } + } + out += formatter.F("\n{{gray}}To silence deprecations that can be silenced set the following environment variable:{{/}}\n") + out += formatter.Fi(1, "{{gray}}ACK_GINKGO_DEPRECATIONS=%s{{/}}\n", config.VERSION) + return out +} + +type SemVer struct { + Major int + Minor int + Patch int +} + +func (s SemVer) GreaterThanOrEqualTo(o SemVer) bool { + return (s.Major > o.Major) || + (s.Major == o.Major && s.Minor > o.Minor) || + (s.Major == o.Major && s.Minor == o.Minor && s.Patch >= o.Patch) +} + +func ParseSemVer(semver string) SemVer { + out := SemVer{} + semver = strings.TrimFunc(semver, func(r rune) bool { + return !(unicode.IsNumber(r) || r == '.') + }) + components := strings.Split(semver, ".") + if len(components) > 0 { + out.Major, _ = strconv.Atoi(components[0]) + } + if len(components) > 1 { + out.Minor, _ = strconv.Atoi(components[1]) + } + if len(components) > 2 { + out.Patch, _ = strconv.Atoi(components[2]) + } + return out +} diff --git a/vendor/github.com/onsi/ginkgo/types/synchronization.go b/vendor/github.com/onsi/ginkgo/types/synchronization.go new file mode 100644 index 00000000..fdd6ed5b --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/types/synchronization.go @@ -0,0 +1,30 @@ +package types + +import ( + "encoding/json" +) + +type RemoteBeforeSuiteState int + +const ( + RemoteBeforeSuiteStateInvalid RemoteBeforeSuiteState = iota + + RemoteBeforeSuiteStatePending + RemoteBeforeSuiteStatePassed + RemoteBeforeSuiteStateFailed + RemoteBeforeSuiteStateDisappeared +) + +type RemoteBeforeSuiteData struct { + Data []byte + State RemoteBeforeSuiteState +} + +func (r RemoteBeforeSuiteData) ToJSON() []byte { + data, _ := json.Marshal(r) + return data +} + +type RemoteAfterSuiteData struct { + CanRun bool +} diff --git a/vendor/github.com/onsi/ginkgo/types/types.go b/vendor/github.com/onsi/ginkgo/types/types.go new file mode 100644 index 00000000..c143e02d --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/types/types.go @@ -0,0 +1,174 @@ +package types + +import ( + "strconv" + "time" +) + +const GINKGO_FOCUS_EXIT_CODE = 197 + +/* +SuiteSummary represents the a summary of the test suite and is passed to both +Reporter.SpecSuiteWillBegin +Reporter.SpecSuiteDidEnd + +this is unfortunate as these two methods should receive different objects. When running in parallel +each node does not deterministically know how many specs it will end up running. + +Unfortunately making such a change would break backward compatibility. + +Until Ginkgo 2.0 comes out we will continue to reuse this struct but populate unknown fields +with -1. +*/ +type SuiteSummary struct { + SuiteDescription string + SuiteSucceeded bool + SuiteID string + + NumberOfSpecsBeforeParallelization int + NumberOfTotalSpecs int + NumberOfSpecsThatWillBeRun int + NumberOfPendingSpecs int + NumberOfSkippedSpecs int + NumberOfPassedSpecs int + NumberOfFailedSpecs int + // Flaked specs are those that failed initially, but then passed on a + // subsequent try. + NumberOfFlakedSpecs int + RunTime time.Duration +} + +type SpecSummary struct { + ComponentTexts []string + ComponentCodeLocations []CodeLocation + + State SpecState + RunTime time.Duration + Failure SpecFailure + IsMeasurement bool + NumberOfSamples int + Measurements map[string]*SpecMeasurement + + CapturedOutput string + SuiteID string +} + +func (s SpecSummary) HasFailureState() bool { + return s.State.IsFailure() +} + +func (s SpecSummary) TimedOut() bool { + return s.State == SpecStateTimedOut +} + +func (s SpecSummary) Panicked() bool { + return s.State == SpecStatePanicked +} + +func (s SpecSummary) Failed() bool { + return s.State == SpecStateFailed +} + +func (s SpecSummary) Passed() bool { + return s.State == SpecStatePassed +} + +func (s SpecSummary) Skipped() bool { + return s.State == SpecStateSkipped +} + +func (s SpecSummary) Pending() bool { + return s.State == SpecStatePending +} + +type SetupSummary struct { + ComponentType SpecComponentType + CodeLocation CodeLocation + + State SpecState + RunTime time.Duration + Failure SpecFailure + + CapturedOutput string + SuiteID string +} + +type SpecFailure struct { + Message string + Location CodeLocation + ForwardedPanic string + + ComponentIndex int + ComponentType SpecComponentType + ComponentCodeLocation CodeLocation +} + +type SpecMeasurement struct { + Name string + Info interface{} + Order int + + Results []float64 + + Smallest float64 + Largest float64 + Average float64 + StdDeviation float64 + + SmallestLabel string + LargestLabel string + AverageLabel string + Units string + Precision int +} + +func (s SpecMeasurement) PrecisionFmt() string { + if s.Precision == 0 { + return "%f" + } + + str := strconv.Itoa(s.Precision) + + return "%." + str + "f" +} + +type SpecState uint + +const ( + SpecStateInvalid SpecState = iota + + SpecStatePending + SpecStateSkipped + SpecStatePassed + SpecStateFailed + SpecStatePanicked + SpecStateTimedOut +) + +func (state SpecState) IsFailure() bool { + return state == SpecStateTimedOut || state == SpecStatePanicked || state == SpecStateFailed +} + +type SpecComponentType uint + +const ( + SpecComponentTypeInvalid SpecComponentType = iota + + SpecComponentTypeContainer + SpecComponentTypeBeforeSuite + SpecComponentTypeAfterSuite + SpecComponentTypeBeforeEach + SpecComponentTypeJustBeforeEach + SpecComponentTypeJustAfterEach + SpecComponentTypeAfterEach + SpecComponentTypeIt + SpecComponentTypeMeasure +) + +type FlagType uint + +const ( + FlagTypeNone FlagType = iota + FlagTypeFocused + FlagTypePending +) diff --git a/vendor/golang.org/x/crypto/cryptobyte/asn1.go b/vendor/golang.org/x/crypto/cryptobyte/asn1.go new file mode 100644 index 00000000..3a1674a1 --- /dev/null +++ b/vendor/golang.org/x/crypto/cryptobyte/asn1.go @@ -0,0 +1,809 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cryptobyte + +import ( + encoding_asn1 "encoding/asn1" + "fmt" + "math/big" + "reflect" + "time" + + "golang.org/x/crypto/cryptobyte/asn1" +) + +// This file contains ASN.1-related methods for String and Builder. + +// Builder + +// AddASN1Int64 appends a DER-encoded ASN.1 INTEGER. +func (b *Builder) AddASN1Int64(v int64) { + b.addASN1Signed(asn1.INTEGER, v) +} + +// AddASN1Int64WithTag appends a DER-encoded ASN.1 INTEGER with the +// given tag. +func (b *Builder) AddASN1Int64WithTag(v int64, tag asn1.Tag) { + b.addASN1Signed(tag, v) +} + +// AddASN1Enum appends a DER-encoded ASN.1 ENUMERATION. +func (b *Builder) AddASN1Enum(v int64) { + b.addASN1Signed(asn1.ENUM, v) +} + +func (b *Builder) addASN1Signed(tag asn1.Tag, v int64) { + b.AddASN1(tag, func(c *Builder) { + length := 1 + for i := v; i >= 0x80 || i < -0x80; i >>= 8 { + length++ + } + + for ; length > 0; length-- { + i := v >> uint((length-1)*8) & 0xff + c.AddUint8(uint8(i)) + } + }) +} + +// AddASN1Uint64 appends a DER-encoded ASN.1 INTEGER. +func (b *Builder) AddASN1Uint64(v uint64) { + b.AddASN1(asn1.INTEGER, func(c *Builder) { + length := 1 + for i := v; i >= 0x80; i >>= 8 { + length++ + } + + for ; length > 0; length-- { + i := v >> uint((length-1)*8) & 0xff + c.AddUint8(uint8(i)) + } + }) +} + +// AddASN1BigInt appends a DER-encoded ASN.1 INTEGER. +func (b *Builder) AddASN1BigInt(n *big.Int) { + if b.err != nil { + return + } + + b.AddASN1(asn1.INTEGER, func(c *Builder) { + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement form. So we + // invert and subtract 1. If the most-significant-bit isn't set then + // we'll need to pad the beginning with 0xff in order to keep the number + // negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + c.add(0xff) + } + c.add(bytes...) + } else if n.Sign() == 0 { + c.add(0) + } else { + bytes := n.Bytes() + if bytes[0]&0x80 != 0 { + c.add(0) + } + c.add(bytes...) + } + }) +} + +// AddASN1OctetString appends a DER-encoded ASN.1 OCTET STRING. +func (b *Builder) AddASN1OctetString(bytes []byte) { + b.AddASN1(asn1.OCTET_STRING, func(c *Builder) { + c.AddBytes(bytes) + }) +} + +const generalizedTimeFormatStr = "20060102150405Z0700" + +// AddASN1GeneralizedTime appends a DER-encoded ASN.1 GENERALIZEDTIME. +func (b *Builder) AddASN1GeneralizedTime(t time.Time) { + if t.Year() < 0 || t.Year() > 9999 { + b.err = fmt.Errorf("cryptobyte: cannot represent %v as a GeneralizedTime", t) + return + } + b.AddASN1(asn1.GeneralizedTime, func(c *Builder) { + c.AddBytes([]byte(t.Format(generalizedTimeFormatStr))) + }) +} + +// AddASN1UTCTime appends a DER-encoded ASN.1 UTCTime. +func (b *Builder) AddASN1UTCTime(t time.Time) { + b.AddASN1(asn1.UTCTime, func(c *Builder) { + // As utilized by the X.509 profile, UTCTime can only + // represent the years 1950 through 2049. + if t.Year() < 1950 || t.Year() >= 2050 { + b.err = fmt.Errorf("cryptobyte: cannot represent %v as a UTCTime", t) + return + } + c.AddBytes([]byte(t.Format(defaultUTCTimeFormatStr))) + }) +} + +// AddASN1BitString appends a DER-encoded ASN.1 BIT STRING. This does not +// support BIT STRINGs that are not a whole number of bytes. +func (b *Builder) AddASN1BitString(data []byte) { + b.AddASN1(asn1.BIT_STRING, func(b *Builder) { + b.AddUint8(0) + b.AddBytes(data) + }) +} + +func (b *Builder) addBase128Int(n int64) { + var length int + if n == 0 { + length = 1 + } else { + for i := n; i > 0; i >>= 7 { + length++ + } + } + + for i := length - 1; i >= 0; i-- { + o := byte(n >> uint(i*7)) + o &= 0x7f + if i != 0 { + o |= 0x80 + } + + b.add(o) + } +} + +func isValidOID(oid encoding_asn1.ObjectIdentifier) bool { + if len(oid) < 2 { + return false + } + + if oid[0] > 2 || (oid[0] <= 1 && oid[1] >= 40) { + return false + } + + for _, v := range oid { + if v < 0 { + return false + } + } + + return true +} + +func (b *Builder) AddASN1ObjectIdentifier(oid encoding_asn1.ObjectIdentifier) { + b.AddASN1(asn1.OBJECT_IDENTIFIER, func(b *Builder) { + if !isValidOID(oid) { + b.err = fmt.Errorf("cryptobyte: invalid OID: %v", oid) + return + } + + b.addBase128Int(int64(oid[0])*40 + int64(oid[1])) + for _, v := range oid[2:] { + b.addBase128Int(int64(v)) + } + }) +} + +func (b *Builder) AddASN1Boolean(v bool) { + b.AddASN1(asn1.BOOLEAN, func(b *Builder) { + if v { + b.AddUint8(0xff) + } else { + b.AddUint8(0) + } + }) +} + +func (b *Builder) AddASN1NULL() { + b.add(uint8(asn1.NULL), 0) +} + +// MarshalASN1 calls encoding_asn1.Marshal on its input and appends the result if +// successful or records an error if one occurred. +func (b *Builder) MarshalASN1(v interface{}) { + // NOTE(martinkr): This is somewhat of a hack to allow propagation of + // encoding_asn1.Marshal errors into Builder.err. N.B. if you call MarshalASN1 with a + // value embedded into a struct, its tag information is lost. + if b.err != nil { + return + } + bytes, err := encoding_asn1.Marshal(v) + if err != nil { + b.err = err + return + } + b.AddBytes(bytes) +} + +// AddASN1 appends an ASN.1 object. The object is prefixed with the given tag. +// Tags greater than 30 are not supported and result in an error (i.e. +// low-tag-number form only). The child builder passed to the +// BuilderContinuation can be used to build the content of the ASN.1 object. +func (b *Builder) AddASN1(tag asn1.Tag, f BuilderContinuation) { + if b.err != nil { + return + } + // Identifiers with the low five bits set indicate high-tag-number format + // (two or more octets), which we don't support. + if tag&0x1f == 0x1f { + b.err = fmt.Errorf("cryptobyte: high-tag number identifier octects not supported: 0x%x", tag) + return + } + b.AddUint8(uint8(tag)) + b.addLengthPrefixed(1, true, f) +} + +// String + +// ReadASN1Boolean decodes an ASN.1 BOOLEAN and converts it to a boolean +// representation into out and advances. It reports whether the read +// was successful. +func (s *String) ReadASN1Boolean(out *bool) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.BOOLEAN) || len(bytes) != 1 { + return false + } + + switch bytes[0] { + case 0: + *out = false + case 0xff: + *out = true + default: + return false + } + + return true +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem() + +// ReadASN1Integer decodes an ASN.1 INTEGER into out and advances. If out does +// not point to an integer or to a big.Int, it panics. It reports whether the +// read was successful. +func (s *String) ReadASN1Integer(out interface{}) bool { + if reflect.TypeOf(out).Kind() != reflect.Ptr { + panic("out is not a pointer") + } + switch reflect.ValueOf(out).Elem().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var i int64 + if !s.readASN1Int64(&i) || reflect.ValueOf(out).Elem().OverflowInt(i) { + return false + } + reflect.ValueOf(out).Elem().SetInt(i) + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var u uint64 + if !s.readASN1Uint64(&u) || reflect.ValueOf(out).Elem().OverflowUint(u) { + return false + } + reflect.ValueOf(out).Elem().SetUint(u) + return true + case reflect.Struct: + if reflect.TypeOf(out).Elem() == bigIntType { + return s.readASN1BigInt(out.(*big.Int)) + } + } + panic("out does not point to an integer type") +} + +func checkASN1Integer(bytes []byte) bool { + if len(bytes) == 0 { + // An INTEGER is encoded with at least one octet. + return false + } + if len(bytes) == 1 { + return true + } + if bytes[0] == 0 && bytes[1]&0x80 == 0 || bytes[0] == 0xff && bytes[1]&0x80 == 0x80 { + // Value is not minimally encoded. + return false + } + return true +} + +var bigOne = big.NewInt(1) + +func (s *String) readASN1BigInt(out *big.Int) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) { + return false + } + if bytes[0]&0x80 == 0x80 { + // Negative number. + neg := make([]byte, len(bytes)) + for i, b := range bytes { + neg[i] = ^b + } + out.SetBytes(neg) + out.Add(out, bigOne) + out.Neg(out) + } else { + out.SetBytes(bytes) + } + return true +} + +func (s *String) readASN1Int64(out *int64) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Signed(out, bytes) { + return false + } + return true +} + +func asn1Signed(out *int64, n []byte) bool { + length := len(n) + if length > 8 { + return false + } + for i := 0; i < length; i++ { + *out <<= 8 + *out |= int64(n[i]) + } + // Shift up and down in order to sign extend the result. + *out <<= 64 - uint8(length)*8 + *out >>= 64 - uint8(length)*8 + return true +} + +func (s *String) readASN1Uint64(out *uint64) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Unsigned(out, bytes) { + return false + } + return true +} + +func asn1Unsigned(out *uint64, n []byte) bool { + length := len(n) + if length > 9 || length == 9 && n[0] != 0 { + // Too large for uint64. + return false + } + if n[0]&0x80 != 0 { + // Negative number. + return false + } + for i := 0; i < length; i++ { + *out <<= 8 + *out |= uint64(n[i]) + } + return true +} + +// ReadASN1Int64WithTag decodes an ASN.1 INTEGER with the given tag into out +// and advances. It reports whether the read was successful and resulted in a +// value that can be represented in an int64. +func (s *String) ReadASN1Int64WithTag(out *int64, tag asn1.Tag) bool { + var bytes String + return s.ReadASN1(&bytes, tag) && checkASN1Integer(bytes) && asn1Signed(out, bytes) +} + +// ReadASN1Enum decodes an ASN.1 ENUMERATION into out and advances. It reports +// whether the read was successful. +func (s *String) ReadASN1Enum(out *int) bool { + var bytes String + var i int64 + if !s.ReadASN1(&bytes, asn1.ENUM) || !checkASN1Integer(bytes) || !asn1Signed(&i, bytes) { + return false + } + if int64(int(i)) != i { + return false + } + *out = int(i) + return true +} + +func (s *String) readBase128Int(out *int) bool { + ret := 0 + for i := 0; len(*s) > 0; i++ { + if i == 5 { + return false + } + // Avoid overflowing int on a 32-bit platform. + // We don't want different behavior based on the architecture. + if ret >= 1<<(31-7) { + return false + } + ret <<= 7 + b := s.read(1)[0] + ret |= int(b & 0x7f) + if b&0x80 == 0 { + *out = ret + return true + } + } + return false // truncated +} + +// ReadASN1ObjectIdentifier decodes an ASN.1 OBJECT IDENTIFIER into out and +// advances. It reports whether the read was successful. +func (s *String) ReadASN1ObjectIdentifier(out *encoding_asn1.ObjectIdentifier) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.OBJECT_IDENTIFIER) || len(bytes) == 0 { + return false + } + + // In the worst case, we get two elements from the first byte (which is + // encoded differently) and then every varint is a single byte long. + components := make([]int, len(bytes)+1) + + // The first varint is 40*value1 + value2: + // According to this packing, value1 can take the values 0, 1 and 2 only. + // When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2, + // then there are no restrictions on value2. + var v int + if !bytes.readBase128Int(&v) { + return false + } + if v < 80 { + components[0] = v / 40 + components[1] = v % 40 + } else { + components[0] = 2 + components[1] = v - 80 + } + + i := 2 + for ; len(bytes) > 0; i++ { + if !bytes.readBase128Int(&v) { + return false + } + components[i] = v + } + *out = components[:i] + return true +} + +// ReadASN1GeneralizedTime decodes an ASN.1 GENERALIZEDTIME into out and +// advances. It reports whether the read was successful. +func (s *String) ReadASN1GeneralizedTime(out *time.Time) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.GeneralizedTime) { + return false + } + t := string(bytes) + res, err := time.Parse(generalizedTimeFormatStr, t) + if err != nil { + return false + } + if serialized := res.Format(generalizedTimeFormatStr); serialized != t { + return false + } + *out = res + return true +} + +const defaultUTCTimeFormatStr = "060102150405Z0700" + +// ReadASN1UTCTime decodes an ASN.1 UTCTime into out and advances. +// It reports whether the read was successful. +func (s *String) ReadASN1UTCTime(out *time.Time) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.UTCTime) { + return false + } + t := string(bytes) + + formatStr := defaultUTCTimeFormatStr + var err error + res, err := time.Parse(formatStr, t) + if err != nil { + // Fallback to minute precision if we can't parse second + // precision. If we are following X.509 or X.690 we shouldn't + // support this, but we do. + formatStr = "0601021504Z0700" + res, err = time.Parse(formatStr, t) + } + if err != nil { + return false + } + + if serialized := res.Format(formatStr); serialized != t { + return false + } + + if res.Year() >= 2050 { + // UTCTime interprets the low order digits 50-99 as 1950-99. + // This only applies to its use in the X.509 profile. + // See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1 + res = res.AddDate(-100, 0, 0) + } + *out = res + return true +} + +// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances. +// It reports whether the read was successful. +func (s *String) ReadASN1BitString(out *encoding_asn1.BitString) bool { + var bytes String + if !s.ReadASN1(&bytes, asn1.BIT_STRING) || len(bytes) == 0 || + len(bytes)*8/8 != len(bytes) { + return false + } + + paddingBits := uint8(bytes[0]) + bytes = bytes[1:] + if paddingBits > 7 || + len(bytes) == 0 && paddingBits != 0 || + len(bytes) > 0 && bytes[len(bytes)-1]&(1< 4 || len(*s) < int(2+lenLen) { + return false + } + + lenBytes := String((*s)[2 : 2+lenLen]) + if !lenBytes.readUnsigned(&len32, int(lenLen)) { + return false + } + + // ITU-T X.690 section 10.1 (DER length forms) requires encoding the length + // with the minimum number of octets. + if len32 < 128 { + // Length should have used short-form encoding. + return false + } + if len32>>((lenLen-1)*8) == 0 { + // Leading octet is 0. Length should have been at least one byte shorter. + return false + } + + headerLen = 2 + uint32(lenLen) + if headerLen+len32 < len32 { + // Overflow. + return false + } + length = headerLen + len32 + } + + if int(length) < 0 || !s.ReadBytes((*[]byte)(out), int(length)) { + return false + } + if skipHeader && !out.Skip(int(headerLen)) { + panic("cryptobyte: internal error") + } + + return true +} diff --git a/vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go b/vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go new file mode 100644 index 00000000..cda8e3ed --- /dev/null +++ b/vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go @@ -0,0 +1,46 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package asn1 contains supporting types for parsing and building ASN.1 +// messages with the cryptobyte package. +package asn1 // import "golang.org/x/crypto/cryptobyte/asn1" + +// Tag represents an ASN.1 identifier octet, consisting of a tag number +// (indicating a type) and class (such as context-specific or constructed). +// +// Methods in the cryptobyte package only support the low-tag-number form, i.e. +// a single identifier octet with bits 7-8 encoding the class and bits 1-6 +// encoding the tag number. +type Tag uint8 + +const ( + classConstructed = 0x20 + classContextSpecific = 0x80 +) + +// Constructed returns t with the constructed class bit set. +func (t Tag) Constructed() Tag { return t | classConstructed } + +// ContextSpecific returns t with the context-specific class bit set. +func (t Tag) ContextSpecific() Tag { return t | classContextSpecific } + +// The following is a list of standard tag and class combinations. +const ( + BOOLEAN = Tag(1) + INTEGER = Tag(2) + BIT_STRING = Tag(3) + OCTET_STRING = Tag(4) + NULL = Tag(5) + OBJECT_IDENTIFIER = Tag(6) + ENUM = Tag(10) + UTF8String = Tag(12) + SEQUENCE = Tag(16 | classConstructed) + SET = Tag(17 | classConstructed) + PrintableString = Tag(19) + T61String = Tag(20) + IA5String = Tag(22) + UTCTime = Tag(23) + GeneralizedTime = Tag(24) + GeneralString = Tag(27) +) diff --git a/vendor/golang.org/x/crypto/cryptobyte/builder.go b/vendor/golang.org/x/crypto/cryptobyte/builder.go new file mode 100644 index 00000000..c7ded757 --- /dev/null +++ b/vendor/golang.org/x/crypto/cryptobyte/builder.go @@ -0,0 +1,337 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cryptobyte + +import ( + "errors" + "fmt" +) + +// A Builder builds byte strings from fixed-length and length-prefixed values. +// Builders either allocate space as needed, or are ‘fixed’, which means that +// they write into a given buffer and produce an error if it's exhausted. +// +// The zero value is a usable Builder that allocates space as needed. +// +// Simple values are marshaled and appended to a Builder using methods on the +// Builder. Length-prefixed values are marshaled by providing a +// BuilderContinuation, which is a function that writes the inner contents of +// the value to a given Builder. See the documentation for BuilderContinuation +// for details. +type Builder struct { + err error + result []byte + fixedSize bool + child *Builder + offset int + pendingLenLen int + pendingIsASN1 bool + inContinuation *bool +} + +// NewBuilder creates a Builder that appends its output to the given buffer. +// Like append(), the slice will be reallocated if its capacity is exceeded. +// Use Bytes to get the final buffer. +func NewBuilder(buffer []byte) *Builder { + return &Builder{ + result: buffer, + } +} + +// NewFixedBuilder creates a Builder that appends its output into the given +// buffer. This builder does not reallocate the output buffer. Writes that +// would exceed the buffer's capacity are treated as an error. +func NewFixedBuilder(buffer []byte) *Builder { + return &Builder{ + result: buffer, + fixedSize: true, + } +} + +// SetError sets the value to be returned as the error from Bytes. Writes +// performed after calling SetError are ignored. +func (b *Builder) SetError(err error) { + b.err = err +} + +// Bytes returns the bytes written by the builder or an error if one has +// occurred during building. +func (b *Builder) Bytes() ([]byte, error) { + if b.err != nil { + return nil, b.err + } + return b.result[b.offset:], nil +} + +// BytesOrPanic returns the bytes written by the builder or panics if an error +// has occurred during building. +func (b *Builder) BytesOrPanic() []byte { + if b.err != nil { + panic(b.err) + } + return b.result[b.offset:] +} + +// AddUint8 appends an 8-bit value to the byte string. +func (b *Builder) AddUint8(v uint8) { + b.add(byte(v)) +} + +// AddUint16 appends a big-endian, 16-bit value to the byte string. +func (b *Builder) AddUint16(v uint16) { + b.add(byte(v>>8), byte(v)) +} + +// AddUint24 appends a big-endian, 24-bit value to the byte string. The highest +// byte of the 32-bit input value is silently truncated. +func (b *Builder) AddUint24(v uint32) { + b.add(byte(v>>16), byte(v>>8), byte(v)) +} + +// AddUint32 appends a big-endian, 32-bit value to the byte string. +func (b *Builder) AddUint32(v uint32) { + b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +// AddBytes appends a sequence of bytes to the byte string. +func (b *Builder) AddBytes(v []byte) { + b.add(v...) +} + +// BuilderContinuation is a continuation-passing interface for building +// length-prefixed byte sequences. Builder methods for length-prefixed +// sequences (AddUint8LengthPrefixed etc) will invoke the BuilderContinuation +// supplied to them. The child builder passed to the continuation can be used +// to build the content of the length-prefixed sequence. For example: +// +// parent := cryptobyte.NewBuilder() +// parent.AddUint8LengthPrefixed(func (child *Builder) { +// child.AddUint8(42) +// child.AddUint8LengthPrefixed(func (grandchild *Builder) { +// grandchild.AddUint8(5) +// }) +// }) +// +// It is an error to write more bytes to the child than allowed by the reserved +// length prefix. After the continuation returns, the child must be considered +// invalid, i.e. users must not store any copies or references of the child +// that outlive the continuation. +// +// If the continuation panics with a value of type BuildError then the inner +// error will be returned as the error from Bytes. If the child panics +// otherwise then Bytes will repanic with the same value. +type BuilderContinuation func(child *Builder) + +// BuildError wraps an error. If a BuilderContinuation panics with this value, +// the panic will be recovered and the inner error will be returned from +// Builder.Bytes. +type BuildError struct { + Err error +} + +// AddUint8LengthPrefixed adds a 8-bit length-prefixed byte sequence. +func (b *Builder) AddUint8LengthPrefixed(f BuilderContinuation) { + b.addLengthPrefixed(1, false, f) +} + +// AddUint16LengthPrefixed adds a big-endian, 16-bit length-prefixed byte sequence. +func (b *Builder) AddUint16LengthPrefixed(f BuilderContinuation) { + b.addLengthPrefixed(2, false, f) +} + +// AddUint24LengthPrefixed adds a big-endian, 24-bit length-prefixed byte sequence. +func (b *Builder) AddUint24LengthPrefixed(f BuilderContinuation) { + b.addLengthPrefixed(3, false, f) +} + +// AddUint32LengthPrefixed adds a big-endian, 32-bit length-prefixed byte sequence. +func (b *Builder) AddUint32LengthPrefixed(f BuilderContinuation) { + b.addLengthPrefixed(4, false, f) +} + +func (b *Builder) callContinuation(f BuilderContinuation, arg *Builder) { + if !*b.inContinuation { + *b.inContinuation = true + + defer func() { + *b.inContinuation = false + + r := recover() + if r == nil { + return + } + + if buildError, ok := r.(BuildError); ok { + b.err = buildError.Err + } else { + panic(r) + } + }() + } + + f(arg) +} + +func (b *Builder) addLengthPrefixed(lenLen int, isASN1 bool, f BuilderContinuation) { + // Subsequent writes can be ignored if the builder has encountered an error. + if b.err != nil { + return + } + + offset := len(b.result) + b.add(make([]byte, lenLen)...) + + if b.inContinuation == nil { + b.inContinuation = new(bool) + } + + b.child = &Builder{ + result: b.result, + fixedSize: b.fixedSize, + offset: offset, + pendingLenLen: lenLen, + pendingIsASN1: isASN1, + inContinuation: b.inContinuation, + } + + b.callContinuation(f, b.child) + b.flushChild() + if b.child != nil { + panic("cryptobyte: internal error") + } +} + +func (b *Builder) flushChild() { + if b.child == nil { + return + } + b.child.flushChild() + child := b.child + b.child = nil + + if child.err != nil { + b.err = child.err + return + } + + length := len(child.result) - child.pendingLenLen - child.offset + + if length < 0 { + panic("cryptobyte: internal error") // result unexpectedly shrunk + } + + if child.pendingIsASN1 { + // For ASN.1, we reserved a single byte for the length. If that turned out + // to be incorrect, we have to move the contents along in order to make + // space. + if child.pendingLenLen != 1 { + panic("cryptobyte: internal error") + } + var lenLen, lenByte uint8 + if int64(length) > 0xfffffffe { + b.err = errors.New("pending ASN.1 child too long") + return + } else if length > 0xffffff { + lenLen = 5 + lenByte = 0x80 | 4 + } else if length > 0xffff { + lenLen = 4 + lenByte = 0x80 | 3 + } else if length > 0xff { + lenLen = 3 + lenByte = 0x80 | 2 + } else if length > 0x7f { + lenLen = 2 + lenByte = 0x80 | 1 + } else { + lenLen = 1 + lenByte = uint8(length) + length = 0 + } + + // Insert the initial length byte, make space for successive length bytes, + // and adjust the offset. + child.result[child.offset] = lenByte + extraBytes := int(lenLen - 1) + if extraBytes != 0 { + child.add(make([]byte, extraBytes)...) + childStart := child.offset + child.pendingLenLen + copy(child.result[childStart+extraBytes:], child.result[childStart:]) + } + child.offset++ + child.pendingLenLen = extraBytes + } + + l := length + for i := child.pendingLenLen - 1; i >= 0; i-- { + child.result[child.offset+i] = uint8(l) + l >>= 8 + } + if l != 0 { + b.err = fmt.Errorf("cryptobyte: pending child length %d exceeds %d-byte length prefix", length, child.pendingLenLen) + return + } + + if b.fixedSize && &b.result[0] != &child.result[0] { + panic("cryptobyte: BuilderContinuation reallocated a fixed-size buffer") + } + + b.result = child.result +} + +func (b *Builder) add(bytes ...byte) { + if b.err != nil { + return + } + if b.child != nil { + panic("cryptobyte: attempted write while child is pending") + } + if len(b.result)+len(bytes) < len(bytes) { + b.err = errors.New("cryptobyte: length overflow") + } + if b.fixedSize && len(b.result)+len(bytes) > cap(b.result) { + b.err = errors.New("cryptobyte: Builder is exceeding its fixed-size buffer") + return + } + b.result = append(b.result, bytes...) +} + +// Unwrite rolls back n bytes written directly to the Builder. An attempt by a +// child builder passed to a continuation to unwrite bytes from its parent will +// panic. +func (b *Builder) Unwrite(n int) { + if b.err != nil { + return + } + if b.child != nil { + panic("cryptobyte: attempted unwrite while child is pending") + } + length := len(b.result) - b.pendingLenLen - b.offset + if length < 0 { + panic("cryptobyte: internal error") + } + if n > length { + panic("cryptobyte: attempted to unwrite more than was written") + } + b.result = b.result[:len(b.result)-n] +} + +// A MarshalingValue marshals itself into a Builder. +type MarshalingValue interface { + // Marshal is called by Builder.AddValue. It receives a pointer to a builder + // to marshal itself into. It may return an error that occurred during + // marshaling, such as unset or invalid values. + Marshal(b *Builder) error +} + +// AddValue calls Marshal on v, passing a pointer to the builder to append to. +// If Marshal returns an error, it is set on the Builder so that subsequent +// appends don't have an effect. +func (b *Builder) AddValue(v MarshalingValue) { + err := v.Marshal(b) + if err != nil { + b.err = err + } +} diff --git a/vendor/golang.org/x/crypto/cryptobyte/string.go b/vendor/golang.org/x/crypto/cryptobyte/string.go new file mode 100644 index 00000000..589d297e --- /dev/null +++ b/vendor/golang.org/x/crypto/cryptobyte/string.go @@ -0,0 +1,161 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cryptobyte contains types that help with parsing and constructing +// length-prefixed, binary messages, including ASN.1 DER. (The asn1 subpackage +// contains useful ASN.1 constants.) +// +// The String type is for parsing. It wraps a []byte slice and provides helper +// functions for consuming structures, value by value. +// +// The Builder type is for constructing messages. It providers helper functions +// for appending values and also for appending length-prefixed submessages – +// without having to worry about calculating the length prefix ahead of time. +// +// See the documentation and examples for the Builder and String types to get +// started. +package cryptobyte // import "golang.org/x/crypto/cryptobyte" + +// String represents a string of bytes. It provides methods for parsing +// fixed-length and length-prefixed values from it. +type String []byte + +// read advances a String by n bytes and returns them. If less than n bytes +// remain, it returns nil. +func (s *String) read(n int) []byte { + if len(*s) < n || n < 0 { + return nil + } + v := (*s)[:n] + *s = (*s)[n:] + return v +} + +// Skip advances the String by n byte and reports whether it was successful. +func (s *String) Skip(n int) bool { + return s.read(n) != nil +} + +// ReadUint8 decodes an 8-bit value into out and advances over it. +// It reports whether the read was successful. +func (s *String) ReadUint8(out *uint8) bool { + v := s.read(1) + if v == nil { + return false + } + *out = uint8(v[0]) + return true +} + +// ReadUint16 decodes a big-endian, 16-bit value into out and advances over it. +// It reports whether the read was successful. +func (s *String) ReadUint16(out *uint16) bool { + v := s.read(2) + if v == nil { + return false + } + *out = uint16(v[0])<<8 | uint16(v[1]) + return true +} + +// ReadUint24 decodes a big-endian, 24-bit value into out and advances over it. +// It reports whether the read was successful. +func (s *String) ReadUint24(out *uint32) bool { + v := s.read(3) + if v == nil { + return false + } + *out = uint32(v[0])<<16 | uint32(v[1])<<8 | uint32(v[2]) + return true +} + +// ReadUint32 decodes a big-endian, 32-bit value into out and advances over it. +// It reports whether the read was successful. +func (s *String) ReadUint32(out *uint32) bool { + v := s.read(4) + if v == nil { + return false + } + *out = uint32(v[0])<<24 | uint32(v[1])<<16 | uint32(v[2])<<8 | uint32(v[3]) + return true +} + +func (s *String) readUnsigned(out *uint32, length int) bool { + v := s.read(length) + if v == nil { + return false + } + var result uint32 + for i := 0; i < length; i++ { + result <<= 8 + result |= uint32(v[i]) + } + *out = result + return true +} + +func (s *String) readLengthPrefixed(lenLen int, outChild *String) bool { + lenBytes := s.read(lenLen) + if lenBytes == nil { + return false + } + var length uint32 + for _, b := range lenBytes { + length = length << 8 + length = length | uint32(b) + } + v := s.read(int(length)) + if v == nil { + return false + } + *outChild = v + return true +} + +// ReadUint8LengthPrefixed reads the content of an 8-bit length-prefixed value +// into out and advances over it. It reports whether the read was successful. +func (s *String) ReadUint8LengthPrefixed(out *String) bool { + return s.readLengthPrefixed(1, out) +} + +// ReadUint16LengthPrefixed reads the content of a big-endian, 16-bit +// length-prefixed value into out and advances over it. It reports whether the +// read was successful. +func (s *String) ReadUint16LengthPrefixed(out *String) bool { + return s.readLengthPrefixed(2, out) +} + +// ReadUint24LengthPrefixed reads the content of a big-endian, 24-bit +// length-prefixed value into out and advances over it. It reports whether +// the read was successful. +func (s *String) ReadUint24LengthPrefixed(out *String) bool { + return s.readLengthPrefixed(3, out) +} + +// ReadBytes reads n bytes into out and advances over them. It reports +// whether the read was successful. +func (s *String) ReadBytes(out *[]byte, n int) bool { + v := s.read(n) + if v == nil { + return false + } + *out = v + return true +} + +// CopyBytes copies len(out) bytes into out and advances over them. It reports +// whether the copy operation was successful +func (s *String) CopyBytes(out []byte) bool { + n := len(out) + v := s.read(n) + if v == nil { + return false + } + return copy(out, v) == n +} + +// Empty reports whether the string does not contain any bytes. +func (s String) Empty() bool { + return len(s) == 0 +} diff --git a/vendor/gopkg.in/tomb.v1/LICENSE b/vendor/gopkg.in/tomb.v1/LICENSE new file mode 100644 index 00000000..a4249bb3 --- /dev/null +++ b/vendor/gopkg.in/tomb.v1/LICENSE @@ -0,0 +1,29 @@ +tomb - support for clean goroutine termination in Go. + +Copyright (c) 2010-2011 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/tomb.v1/README.md b/vendor/gopkg.in/tomb.v1/README.md new file mode 100644 index 00000000..3ae8788e --- /dev/null +++ b/vendor/gopkg.in/tomb.v1/README.md @@ -0,0 +1,4 @@ +Installation and usage +---------------------- + +See [gopkg.in/tomb.v1](https://gopkg.in/tomb.v1) for documentation and usage details. diff --git a/vendor/gopkg.in/tomb.v1/tomb.go b/vendor/gopkg.in/tomb.v1/tomb.go new file mode 100644 index 00000000..9aec56d8 --- /dev/null +++ b/vendor/gopkg.in/tomb.v1/tomb.go @@ -0,0 +1,176 @@ +// Copyright (c) 2011 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The tomb package offers a conventional API for clean goroutine termination. +// +// A Tomb tracks the lifecycle of a goroutine as alive, dying or dead, +// and the reason for its death. +// +// The zero value of a Tomb assumes that a goroutine is about to be +// created or already alive. Once Kill or Killf is called with an +// argument that informs the reason for death, the goroutine is in +// a dying state and is expected to terminate soon. Right before the +// goroutine function or method returns, Done must be called to inform +// that the goroutine is indeed dead and about to stop running. +// +// A Tomb exposes Dying and Dead channels. These channels are closed +// when the Tomb state changes in the respective way. They enable +// explicit blocking until the state changes, and also to selectively +// unblock select statements accordingly. +// +// When the tomb state changes to dying and there's still logic going +// on within the goroutine, nested functions and methods may choose to +// return ErrDying as their error value, as this error won't alter the +// tomb state if provided to the Kill method. This is a convenient way to +// follow standard Go practices in the context of a dying tomb. +// +// For background and a detailed example, see the following blog post: +// +// http://blog.labix.org/2011/10/09/death-of-goroutines-under-control +// +// For a more complex code snippet demonstrating the use of multiple +// goroutines with a single Tomb, see: +// +// http://play.golang.org/p/Xh7qWsDPZP +// +package tomb + +import ( + "errors" + "fmt" + "sync" +) + +// A Tomb tracks the lifecycle of a goroutine as alive, dying or dead, +// and the reason for its death. +// +// See the package documentation for details. +type Tomb struct { + m sync.Mutex + dying chan struct{} + dead chan struct{} + reason error +} + +var ( + ErrStillAlive = errors.New("tomb: still alive") + ErrDying = errors.New("tomb: dying") +) + +func (t *Tomb) init() { + t.m.Lock() + if t.dead == nil { + t.dead = make(chan struct{}) + t.dying = make(chan struct{}) + t.reason = ErrStillAlive + } + t.m.Unlock() +} + +// Dead returns the channel that can be used to wait +// until t.Done has been called. +func (t *Tomb) Dead() <-chan struct{} { + t.init() + return t.dead +} + +// Dying returns the channel that can be used to wait +// until t.Kill or t.Done has been called. +func (t *Tomb) Dying() <-chan struct{} { + t.init() + return t.dying +} + +// Wait blocks until the goroutine is in a dead state and returns the +// reason for its death. +func (t *Tomb) Wait() error { + t.init() + <-t.dead + t.m.Lock() + reason := t.reason + t.m.Unlock() + return reason +} + +// Done flags the goroutine as dead, and should be called a single time +// right before the goroutine function or method returns. +// If the goroutine was not already in a dying state before Done is +// called, it will be flagged as dying and dead at once with no +// error. +func (t *Tomb) Done() { + t.Kill(nil) + close(t.dead) +} + +// Kill flags the goroutine as dying for the given reason. +// Kill may be called multiple times, but only the first +// non-nil error is recorded as the reason for termination. +// +// If reason is ErrDying, the previous reason isn't replaced +// even if it is nil. It's a runtime error to call Kill with +// ErrDying if t is not in a dying state. +func (t *Tomb) Kill(reason error) { + t.init() + t.m.Lock() + defer t.m.Unlock() + if reason == ErrDying { + if t.reason == ErrStillAlive { + panic("tomb: Kill with ErrDying while still alive") + } + return + } + if t.reason == nil || t.reason == ErrStillAlive { + t.reason = reason + } + // If the receive on t.dying succeeds, then + // it can only be because we have already closed it. + // If it blocks, then we know that it needs to be closed. + select { + case <-t.dying: + default: + close(t.dying) + } +} + +// Killf works like Kill, but builds the reason providing the received +// arguments to fmt.Errorf. The generated error is also returned. +func (t *Tomb) Killf(f string, a ...interface{}) error { + err := fmt.Errorf(f, a...) + t.Kill(err) + return err +} + +// Err returns the reason for the goroutine death provided via Kill +// or Killf, or ErrStillAlive when the goroutine is still alive. +func (t *Tomb) Err() (reason error) { + t.init() + t.m.Lock() + reason = t.reason + t.m.Unlock() + return +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 7209daa7..83332bab 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -44,6 +44,12 @@ github.com/charithe/durationcheck # github.com/chavacava/garif v0.0.0-20210405164556-e8a0a408d6af ## explicit; go 1.16 github.com/chavacava/garif +# github.com/cheekybits/genny v1.0.0 +## explicit +github.com/cheekybits/genny +github.com/cheekybits/genny/generic +github.com/cheekybits/genny/out +github.com/cheekybits/genny/parse # github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf ## explicit github.com/coreos/go-systemd/activation @@ -84,6 +90,9 @@ github.com/go-critic/go-critic/checkers github.com/go-critic/go-critic/checkers/internal/astwalk github.com/go-critic/go-critic/checkers/internal/lintutil github.com/go-critic/go-critic/framework/linter +# github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 +## explicit; go 1.13 +github.com/go-task/slim-sprig # github.com/go-toolsmith/astcast v1.0.0 ## explicit github.com/go-toolsmith/astcast @@ -121,7 +130,7 @@ github.com/gobwas/glob/util/strings # github.com/gofrs/flock v0.8.0 ## explicit github.com/gofrs/flock -# github.com/golang/protobuf v1.5.0 +# github.com/golang/protobuf v1.5.2 ## explicit; go 1.9 github.com/golang/protobuf/proto github.com/golang/protobuf/ptypes @@ -306,12 +315,43 @@ github.com/ldez/gomoddirectives # github.com/ldez/tagliatelle v0.2.0 ## explicit; go 1.16 github.com/ldez/tagliatelle +# github.com/lucas-clemente/quic-go v0.28.0 +## explicit; go 1.16 +github.com/lucas-clemente/quic-go +github.com/lucas-clemente/quic-go/http3 +github.com/lucas-clemente/quic-go/internal/ackhandler +github.com/lucas-clemente/quic-go/internal/congestion +github.com/lucas-clemente/quic-go/internal/flowcontrol +github.com/lucas-clemente/quic-go/internal/handshake +github.com/lucas-clemente/quic-go/internal/logutils +github.com/lucas-clemente/quic-go/internal/protocol +github.com/lucas-clemente/quic-go/internal/qerr +github.com/lucas-clemente/quic-go/internal/qtls +github.com/lucas-clemente/quic-go/internal/utils +github.com/lucas-clemente/quic-go/internal/wire +github.com/lucas-clemente/quic-go/logging +github.com/lucas-clemente/quic-go/quicvarint # github.com/magiconair/properties v1.8.1 ## explicit github.com/magiconair/properties # github.com/maratori/testpackage v1.0.1 ## explicit; go 1.12 github.com/maratori/testpackage/pkg/testpackage +# github.com/marten-seemann/qpack v0.2.1 +## explicit; go 1.14 +github.com/marten-seemann/qpack +# github.com/marten-seemann/qtls-go1-16 v0.1.5 +## explicit; go 1.16 +github.com/marten-seemann/qtls-go1-16 +# github.com/marten-seemann/qtls-go1-17 v0.1.2 +## explicit; go 1.17 +github.com/marten-seemann/qtls-go1-17 +# github.com/marten-seemann/qtls-go1-18 v0.1.2 +## explicit; go 1.18 +github.com/marten-seemann/qtls-go1-18 +# github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 +## explicit; go 1.19 +github.com/marten-seemann/qtls-go1-19 # github.com/matoous/godox v0.0.0-20210227103229-6504466cf951 ## explicit; go 1.13 github.com/matoous/godox @@ -376,9 +416,41 @@ github.com/nishanths/exhaustive # github.com/nishanths/predeclared v0.2.1 ## explicit; go 1.14 github.com/nishanths/predeclared/passes/predeclared +# github.com/nxadm/tail v1.4.8 +## explicit; go 1.13 +github.com/nxadm/tail +github.com/nxadm/tail/ratelimiter +github.com/nxadm/tail/util +github.com/nxadm/tail/watch +github.com/nxadm/tail/winfile # github.com/olekukonko/tablewriter v0.0.5 ## explicit; go 1.12 github.com/olekukonko/tablewriter +# github.com/onsi/ginkgo v1.16.4 +## explicit; go 1.15 +github.com/onsi/ginkgo/config +github.com/onsi/ginkgo/formatter +github.com/onsi/ginkgo/ginkgo +github.com/onsi/ginkgo/ginkgo/convert +github.com/onsi/ginkgo/ginkgo/interrupthandler +github.com/onsi/ginkgo/ginkgo/nodot +github.com/onsi/ginkgo/ginkgo/outline +github.com/onsi/ginkgo/ginkgo/testrunner +github.com/onsi/ginkgo/ginkgo/testsuite +github.com/onsi/ginkgo/ginkgo/watch +github.com/onsi/ginkgo/internal/codelocation +github.com/onsi/ginkgo/internal/containernode +github.com/onsi/ginkgo/internal/failer +github.com/onsi/ginkgo/internal/leafnodes +github.com/onsi/ginkgo/internal/remote +github.com/onsi/ginkgo/internal/spec +github.com/onsi/ginkgo/internal/spec_iterator +github.com/onsi/ginkgo/internal/writer +github.com/onsi/ginkgo/reporters +github.com/onsi/ginkgo/reporters/stenographer +github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable +github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty +github.com/onsi/ginkgo/types # github.com/pelletier/go-toml v1.2.0 ## explicit github.com/pelletier/go-toml @@ -529,6 +601,8 @@ github.com/yeya24/promlinter golang.org/x/crypto/blake2b golang.org/x/crypto/chacha20 golang.org/x/crypto/chacha20poly1305 +golang.org/x/crypto/cryptobyte +golang.org/x/crypto/cryptobyte/asn1 golang.org/x/crypto/curve25519 golang.org/x/crypto/curve25519/internal/field golang.org/x/crypto/ed25519 @@ -699,6 +773,9 @@ gopkg.in/ini.v1 # gopkg.in/natefinch/lumberjack.v2 v2.0.0 ## explicit gopkg.in/natefinch/lumberjack.v2 +# gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 +## explicit +gopkg.in/tomb.v1 # gopkg.in/yaml.v2 v2.4.0 ## explicit; go 1.15 gopkg.in/yaml.v2