mirror of
https://github.com/DNSCrypt/dnscrypt-proxy.git
synced 2025-04-06 06:37:36 +03:00
Update quic-go dependency to support go 1.20 (#2292)
This commit is contained in:
parent
5438eed2f4
commit
c3fd855831
296 changed files with 14851 additions and 2397 deletions
17
vendor/github.com/quic-go/quic-go/.gitignore
generated
vendored
Normal file
17
vendor/github.com/quic-go/quic-go/.gitignore
generated
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
debug
|
||||
debug.test
|
||||
main
|
||||
mockgen_tmp.go
|
||||
*.qtr
|
||||
*.qlog
|
||||
*.txt
|
||||
race.[0-9]*
|
||||
|
||||
fuzzing/*/*.zip
|
||||
fuzzing/*/coverprofile
|
||||
fuzzing/*/crashers
|
||||
fuzzing/*/sonarprofile
|
||||
fuzzing/*/suppressions
|
||||
fuzzing/*/corpus/
|
||||
|
||||
gomock_reflect_*/
|
44
vendor/github.com/quic-go/quic-go/.golangci.yml
generated
vendored
Normal file
44
vendor/github.com/quic-go/quic-go/.golangci.yml
generated
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
run:
|
||||
skip-files:
|
||||
- internal/qtls/structs_equal_test.go
|
||||
|
||||
linters-settings:
|
||||
depguard:
|
||||
type: blacklist
|
||||
packages:
|
||||
- github.com/marten-seemann/qtls
|
||||
packages-with-error-message:
|
||||
- github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls"
|
||||
misspell:
|
||||
ignore-words:
|
||||
- ect
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- asciicheck
|
||||
- deadcode
|
||||
- depguard
|
||||
- exhaustive
|
||||
- exportloopref
|
||||
- goimports
|
||||
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
|
||||
- gofumpt
|
||||
- gosimple
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
- staticcheck
|
||||
- stylecheck
|
||||
- structcheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- varcheck
|
||||
- vet
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
- path: internal/qtls
|
||||
linters:
|
||||
- depguard
|
109
vendor/github.com/quic-go/quic-go/Changelog.md
generated
vendored
Normal file
109
vendor/github.com/quic-go/quic-go/Changelog.md
generated
vendored
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Changelog
|
||||
|
||||
## v0.22.0 (2021-07-25)
|
||||
|
||||
- Use `ReadBatch` to read multiple UDP packets from the socket with a single syscall
|
||||
- Add a config option (`Config.DisableVersionNegotiationPackets`) to disable sending of Version Negotiation packets
|
||||
- Drop support for QUIC draft versions 32 and 34
|
||||
- Remove the `RetireBugBackwardsCompatibilityMode`, which was intended to mitigate a bug when retiring connection IDs in quic-go in v0.17.2 and ealier
|
||||
|
||||
## v0.21.2 (2021-07-15)
|
||||
|
||||
- Update qtls (for Go 1.15, 1.16 and 1.17rc1) to include the fix for the crypto/tls panic (see https://groups.google.com/g/golang-dev/c/5LJ2V7rd-Ag/m/YGLHVBZ6AAAJ for details)
|
||||
|
||||
## v0.21.0 (2021-06-01)
|
||||
|
||||
- quic-go now supports RFC 9000!
|
||||
|
||||
## v0.20.0 (2021-03-19)
|
||||
|
||||
- Remove the `quic.Config.HandshakeTimeout`. Introduce a `quic.Config.HandshakeIdleTimeout`.
|
||||
|
||||
## v0.17.1 (2020-06-20)
|
||||
|
||||
- Supports QUIC WG draft-29.
|
||||
- Improve bundling of ACK frames (#2543).
|
||||
|
||||
## v0.16.0 (2020-05-31)
|
||||
|
||||
- Supports QUIC WG draft-28.
|
||||
|
||||
## v0.15.0 (2020-03-01)
|
||||
|
||||
- Supports QUIC WG draft-27.
|
||||
- Add support for 0-RTT.
|
||||
- Remove `Session.Close()`. Applications need to pass an application error code to the transport using `Session.CloseWithError()`.
|
||||
- Make the TLS Cipher Suites configurable (via `tls.Config.CipherSuites`).
|
||||
|
||||
## v0.14.0 (2019-12-04)
|
||||
|
||||
- Supports QUIC WG draft-24.
|
||||
|
||||
## v0.13.0 (2019-11-05)
|
||||
|
||||
- Supports QUIC WG draft-23.
|
||||
- Add an `EarlyListener` that allows sending of 0.5-RTT data.
|
||||
- Add a `TokenStore` to store address validation tokens.
|
||||
- Issue and use new connection IDs during a connection.
|
||||
|
||||
## v0.12.0 (2019-08-05)
|
||||
|
||||
- Implement HTTP/3.
|
||||
- Rename `quic.Cookie` to `quic.Token` and `quic.Config.AcceptCookie` to `quic.Config.AcceptToken`.
|
||||
- Distinguish between Retry tokens and tokens sent in NEW_TOKEN frames.
|
||||
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
|
||||
- Use a varint for error codes.
|
||||
- Add support for [quic-trace](https://github.com/google/quic-trace).
|
||||
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
|
||||
- Implement TLS key updates.
|
||||
|
||||
## v0.11.0 (2019-04-05)
|
||||
|
||||
- Drop support for gQUIC. For qQUIC support, please switch to the *gquic* branch.
|
||||
- Implement QUIC WG draft-19.
|
||||
- Use [qtls](https://github.com/marten-seemann/qtls) for TLS 1.3.
|
||||
- Return a `tls.ConnectionState` from `quic.Session.ConnectionState()`.
|
||||
- Remove the error return values from `quic.Stream.CancelRead()` and `quic.Stream.CancelWrite()`
|
||||
|
||||
## v0.10.0 (2018-08-28)
|
||||
|
||||
- Add support for QUIC 44, drop support for QUIC 42.
|
||||
|
||||
## v0.9.0 (2018-08-15)
|
||||
|
||||
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
|
||||
- Split Session.Close into one method for regular closing and one for closing with an error.
|
||||
|
||||
## v0.8.0 (2018-06-26)
|
||||
|
||||
- Add support for unidirectional streams (for IETF QUIC).
|
||||
- Add a `quic.Config` option for the maximum number of incoming streams.
|
||||
- Add support for QUIC 42 and 43.
|
||||
- Add dial functions that use a context.
|
||||
- Multiplex clients on a net.PacketConn, when using Dial(conn).
|
||||
|
||||
## v0.7.0 (2018-02-03)
|
||||
|
||||
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
|
||||
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
|
||||
- Expose the `ConnectionState` in the `Session` (experimental API).
|
||||
- Implement packet pacing.
|
||||
|
||||
## v0.6.0 (2017-12-12)
|
||||
|
||||
- Add support for QUIC 39, drop support for QUIC 35 - 37
|
||||
- Added `quic.Config` options for maximal flow control windows
|
||||
- Add a `quic.Config` option for QUIC versions
|
||||
- Add a `quic.Config` option to request omission of the connection ID from a server
|
||||
- Add a `quic.Config` option to configure the source address validation
|
||||
- Add a `quic.Config` option to configure the handshake timeout
|
||||
- Add a `quic.Config` option to configure the idle timeout
|
||||
- Add a `quic.Config` option to configure keep-alive
|
||||
- Rename the STK to Cookie
|
||||
- Implement `net.Conn`-style deadlines for streams
|
||||
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/quic-go/quic-go) for details.
|
||||
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/quic-go/quic-go/wiki/Logging) for more details.
|
||||
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
|
||||
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
|
||||
- Drop support for Go 1.7 and 1.8.
|
||||
- Various bugfixes
|
21
vendor/github.com/quic-go/quic-go/LICENSE
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/LICENSE
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2016 the quic-go authors & Google, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
63
vendor/github.com/quic-go/quic-go/README.md
generated
vendored
Normal file
63
vendor/github.com/quic-go/quic-go/README.md
generated
vendored
Normal file
|
@ -0,0 +1,63 @@
|
|||
# A QUIC implementation in pure Go
|
||||
|
||||
<img src="docs/quic.png" width=303 height=124>
|
||||
|
||||
[](https://pkg.go.dev/github.com/quic-go/quic-go)
|
||||
[](https://codecov.io/gh/quic-go/quic-go/)
|
||||
|
||||
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) and Datagram Packetization Layer Path MTU
|
||||
Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)). It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
|
||||
|
||||
In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem.
|
||||
|
||||
## Guides
|
||||
|
||||
*We currently support Go 1.18.x and Go 1.19.x.*
|
||||
|
||||
Running tests:
|
||||
|
||||
go test ./...
|
||||
|
||||
### QUIC without HTTP/3
|
||||
|
||||
Take a look at [this echo example](example/echo/echo.go).
|
||||
|
||||
## Usage
|
||||
|
||||
### As a server
|
||||
|
||||
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
|
||||
|
||||
```go
|
||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||
http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
|
||||
```
|
||||
|
||||
### As a client
|
||||
|
||||
See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`.
|
||||
|
||||
```go
|
||||
http.Client{
|
||||
Transport: &http3.RoundTripper{},
|
||||
}
|
||||
```
|
||||
|
||||
## Projects using quic-go
|
||||
|
||||
| Project | Description | Stars |
|
||||
|-----------------------------------------------------------|---------------------------------------------------------------------------------------------------------|-------|
|
||||
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. |  |
|
||||
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support |  |
|
||||
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS |  |
|
||||
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins |  |
|
||||
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others |  |
|
||||
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. |  |
|
||||
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization |  |
|
||||
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy |  |
|
||||
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions |  |
|
||||
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System |  |
|
||||
|
||||
## Contributing
|
||||
|
||||
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
|
80
vendor/github.com/quic-go/quic-go/buffer_pool.go
generated
vendored
Normal file
80
vendor/github.com/quic-go/quic-go/buffer_pool.go
generated
vendored
Normal file
|
@ -0,0 +1,80 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type packetBuffer struct {
|
||||
Data []byte
|
||||
|
||||
// refCount counts how many packets Data is used in.
|
||||
// It doesn't support concurrent use.
|
||||
// It is > 1 when used for coalesced packet.
|
||||
refCount int
|
||||
}
|
||||
|
||||
// Split increases the refCount.
|
||||
// It must be called when a packet buffer is used for more than one packet,
|
||||
// e.g. when splitting coalesced packets.
|
||||
func (b *packetBuffer) Split() {
|
||||
b.refCount++
|
||||
}
|
||||
|
||||
// Decrement decrements the reference counter.
|
||||
// It doesn't put the buffer back into the pool.
|
||||
func (b *packetBuffer) Decrement() {
|
||||
b.refCount--
|
||||
if b.refCount < 0 {
|
||||
panic("negative packetBuffer refCount")
|
||||
}
|
||||
}
|
||||
|
||||
// MaybeRelease puts the packet buffer back into the pool,
|
||||
// if the reference counter already reached 0.
|
||||
func (b *packetBuffer) MaybeRelease() {
|
||||
// only put the packetBuffer back if it's not used any more
|
||||
if b.refCount == 0 {
|
||||
b.putBack()
|
||||
}
|
||||
}
|
||||
|
||||
// Release puts back the packet buffer into the pool.
|
||||
// It should be called when processing is definitely finished.
|
||||
func (b *packetBuffer) Release() {
|
||||
b.Decrement()
|
||||
if b.refCount != 0 {
|
||||
panic("packetBuffer refCount not zero")
|
||||
}
|
||||
b.putBack()
|
||||
}
|
||||
|
||||
// Len returns the length of Data
|
||||
func (b *packetBuffer) Len() protocol.ByteCount {
|
||||
return protocol.ByteCount(len(b.Data))
|
||||
}
|
||||
|
||||
func (b *packetBuffer) putBack() {
|
||||
if cap(b.Data) != int(protocol.MaxPacketBufferSize) {
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
bufferPool.Put(b)
|
||||
}
|
||||
|
||||
var bufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() *packetBuffer {
|
||||
buf := bufferPool.Get().(*packetBuffer)
|
||||
buf.refCount = 1
|
||||
buf.Data = buf.Data[:0]
|
||||
return buf
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() interface{} {
|
||||
return &packetBuffer{
|
||||
Data: make([]byte, 0, protocol.MaxPacketBufferSize),
|
||||
}
|
||||
}
|
||||
}
|
332
vendor/github.com/quic-go/quic-go/client.go
generated
vendored
Normal file
332
vendor/github.com/quic-go/quic-go/client.go
generated
vendored
Normal file
|
@ -0,0 +1,332 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
sconn sendConn
|
||||
// If the client is created with DialAddr, we create a packet conn.
|
||||
// If it is started with Dial, we take a packet conn as a parameter.
|
||||
createdPacketConn bool
|
||||
|
||||
use0RTT bool
|
||||
|
||||
packetHandlers packetHandlerManager
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialPacketNumber protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
version protocol.VersionNumber
|
||||
|
||||
handshakeChan chan struct{}
|
||||
|
||||
conn quicConn
|
||||
|
||||
tracer logging.ConnectionTracer
|
||||
tracingID uint64
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// make it possible to mock connection ID for initial generation in the tests
|
||||
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
|
||||
func DialAddr(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Connection, error) {
|
||||
return DialAddrContext(context.Background(), addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
// The hostname for SNI is taken from the given address.
|
||||
// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites.
|
||||
func DialAddrEarly(
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (EarlyConnection, error) {
|
||||
return DialAddrEarlyContext(context.Background(), addr, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context.
|
||||
// See DialAddrEarly for details
|
||||
func DialAddrEarlyContext(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (EarlyConnection, error) {
|
||||
conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
|
||||
// See DialAddr for details.
|
||||
func DialAddrContext(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Connection, error) {
|
||||
return dialAddrContext(ctx, addr, tlsConf, config, false)
|
||||
}
|
||||
|
||||
func dialAddrContext(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
use0RTT bool,
|
||||
) (quicConn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true)
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
|
||||
// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
|
||||
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
|
||||
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
|
||||
// packets. The same PacketConn can be used for multiple calls to Dial and
|
||||
// Listen, QUIC connection IDs are used for demultiplexing the different
|
||||
// connections. The host parameter is used for SNI. The tls.Config must define
|
||||
// an application protocol (using NextProtos).
|
||||
func Dial(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Connection, error) {
|
||||
return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false)
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
|
||||
// The same PacketConn can be used for multiple calls to Dial and Listen,
|
||||
// QUIC connection IDs are used for demultiplexing the different connections.
|
||||
// The host parameter is used for SNI.
|
||||
// The tls.Config must define an application protocol (using NextProtos).
|
||||
func DialEarly(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (EarlyConnection, error) {
|
||||
return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
|
||||
}
|
||||
|
||||
// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// See DialEarly for details.
|
||||
func DialEarlyContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (EarlyConnection, error) {
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false)
|
||||
}
|
||||
|
||||
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// See Dial for details.
|
||||
func DialContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Connection, error) {
|
||||
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false)
|
||||
}
|
||||
|
||||
func dialContext(
|
||||
ctx context.Context,
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
host string,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
use0RTT bool,
|
||||
createdPacketConn bool,
|
||||
) (quicConn, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
if err := validateConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config = populateClientConfig(config, createdPacketConn)
|
||||
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.packetHandlers = packetHandlers
|
||||
|
||||
c.tracingID = nextConnTracingID()
|
||||
if c.config.Tracer != nil {
|
||||
c.tracer = c.config.Tracer.TracerForConnection(
|
||||
context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
|
||||
protocol.PerspectiveClient,
|
||||
c.destConnID,
|
||||
)
|
||||
}
|
||||
if c.tracer != nil {
|
||||
c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
host string,
|
||||
use0RTT bool,
|
||||
createdPacketConn bool,
|
||||
) (*client, error) {
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = tlsConf.Clone()
|
||||
}
|
||||
if tlsConf.ServerName == "" {
|
||||
sni, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
|
||||
sni = host
|
||||
}
|
||||
|
||||
tlsConf.ServerName = sni
|
||||
}
|
||||
|
||||
// check that all versions are actually supported
|
||||
if config != nil {
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &client{
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sconn: newSendPconn(pconn, remoteAddr),
|
||||
createdPacketConn: createdPacketConn,
|
||||
use0RTT: use0RTT,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
c.conn = newClientConnection(
|
||||
c.sconn,
|
||||
c.packetHandlers,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.config,
|
||||
c.tlsConf,
|
||||
c.initialPacketNumber,
|
||||
c.use0RTT,
|
||||
c.hasNegotiatedVersion,
|
||||
c.tracer,
|
||||
c.tracingID,
|
||||
c.logger,
|
||||
c.version,
|
||||
)
|
||||
c.packetHandlers.Add(c.srcConnID, c.conn)
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := c.conn.run() // returns as soon as the connection is closed
|
||||
|
||||
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn {
|
||||
c.packetHandlers.Destroy()
|
||||
}
|
||||
errorChan <- err
|
||||
}()
|
||||
|
||||
// only set when we're using 0-RTT
|
||||
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
|
||||
var earlyConnChan <-chan struct{}
|
||||
if c.use0RTT {
|
||||
earlyConnChan = c.conn.earlyConnReady()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.conn.shutdown()
|
||||
return ctx.Err()
|
||||
case err := <-errorChan:
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
}
|
||||
return err
|
||||
case <-earlyConnChan:
|
||||
// ready to send 0-RTT data
|
||||
return nil
|
||||
case <-c.conn.HandshakeComplete().Done():
|
||||
// handshake successfully completed
|
||||
return nil
|
||||
}
|
||||
}
|
64
vendor/github.com/quic-go/quic-go/closed_conn.go
generated
vendored
Normal file
64
vendor/github.com/quic-go/quic-go/closed_conn.go
generated
vendored
Normal file
|
@ -0,0 +1,64 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A closedLocalConn is a connection that we closed locally.
|
||||
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
||||
// with an exponential backoff.
|
||||
type closedLocalConn struct {
|
||||
counter uint32
|
||||
perspective protocol.Perspective
|
||||
logger utils.Logger
|
||||
|
||||
sendPacket func(net.Addr, *packetInfo)
|
||||
}
|
||||
|
||||
var _ packetHandler = &closedLocalConn{}
|
||||
|
||||
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
||||
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
|
||||
return &closedLocalConn{
|
||||
sendPacket: sendPacket,
|
||||
perspective: pers,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) handlePacket(p *receivedPacket) {
|
||||
c.counter++
|
||||
// exponential backoff
|
||||
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
|
||||
if bits.OnesCount32(c.counter) != 1 {
|
||||
return
|
||||
}
|
||||
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
|
||||
c.sendPacket(p.remoteAddr, p.info)
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) shutdown() {}
|
||||
func (c *closedLocalConn) destroy(error) {}
|
||||
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
|
||||
|
||||
// A closedRemoteConn is a connection that was closed remotely.
|
||||
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
||||
// We can just ignore those packets.
|
||||
type closedRemoteConn struct {
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ packetHandler = &closedRemoteConn{}
|
||||
|
||||
func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
|
||||
return &closedRemoteConn{perspective: pers}
|
||||
}
|
||||
|
||||
func (s *closedRemoteConn) handlePacket(*receivedPacket) {}
|
||||
func (s *closedRemoteConn) shutdown() {}
|
||||
func (s *closedRemoteConn) destroy(error) {}
|
||||
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
|
22
vendor/github.com/quic-go/quic-go/codecov.yml
generated
vendored
Normal file
22
vendor/github.com/quic-go/quic-go/codecov.yml
generated
vendored
Normal file
|
@ -0,0 +1,22 @@
|
|||
coverage:
|
||||
round: nearest
|
||||
ignore:
|
||||
- streams_map_incoming_bidi.go
|
||||
- streams_map_incoming_uni.go
|
||||
- streams_map_outgoing_bidi.go
|
||||
- streams_map_outgoing_uni.go
|
||||
- http3/gzip_reader.go
|
||||
- interop/
|
||||
- internal/ackhandler/packet_linkedlist.go
|
||||
- internal/utils/byteinterval_linkedlist.go
|
||||
- internal/utils/newconnectionid_linkedlist.go
|
||||
- internal/utils/packetinterval_linkedlist.go
|
||||
- internal/utils/linkedlist/linkedlist.go
|
||||
- logging/null_tracer.go
|
||||
- fuzzing/
|
||||
- metrics/
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
threshold: 0.5
|
||||
patch: false
|
141
vendor/github.com/quic-go/quic-go/config.go
generated
vendored
Normal file
141
vendor/github.com/quic-go/quic-go/config.go
generated
vendored
Normal file
|
@ -0,0 +1,141 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Clone clones a Config
|
||||
func (c *Config) Clone() *Config {
|
||||
copy := *c
|
||||
return ©
|
||||
}
|
||||
|
||||
func (c *Config) handshakeTimeout() time.Duration {
|
||||
return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout)
|
||||
}
|
||||
|
||||
func validateConfig(config *Config) error {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
if config.MaxIncomingStreams > 1<<60 {
|
||||
return errors.New("invalid value for Config.MaxIncomingStreams")
|
||||
}
|
||||
if config.MaxIncomingUniStreams > 1<<60 {
|
||||
return errors.New("invalid value for Config.MaxIncomingUniStreams")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config, protocol.DefaultConnectionIDLength)
|
||||
if config.MaxTokenAge == 0 {
|
||||
config.MaxTokenAge = protocol.TokenValidity
|
||||
}
|
||||
if config.MaxRetryTokenAge == 0 {
|
||||
config.MaxRetryTokenAge = protocol.RetryTokenValidity
|
||||
}
|
||||
if config.RequireAddressValidation == nil {
|
||||
config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
|
||||
defaultConnIDLen := protocol.DefaultConnectionIDLength
|
||||
if createdPacketConn {
|
||||
defaultConnIDLen = 0
|
||||
}
|
||||
|
||||
config = populateConfig(config, defaultConnIDLen)
|
||||
return config
|
||||
}
|
||||
|
||||
func populateConfig(config *Config, defaultConnIDLen int) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
versions := config.Versions
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
conIDLen := config.ConnectionIDLength
|
||||
if config.ConnectionIDLength == 0 {
|
||||
conIDLen = defaultConnIDLen
|
||||
}
|
||||
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
|
||||
if config.HandshakeIdleTimeout != 0 {
|
||||
handshakeIdleTimeout = config.HandshakeIdleTimeout
|
||||
}
|
||||
idleTimeout := protocol.DefaultIdleTimeout
|
||||
if config.MaxIdleTimeout != 0 {
|
||||
idleTimeout = config.MaxIdleTimeout
|
||||
}
|
||||
initialStreamReceiveWindow := config.InitialStreamReceiveWindow
|
||||
if initialStreamReceiveWindow == 0 {
|
||||
initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData
|
||||
}
|
||||
maxStreamReceiveWindow := config.MaxStreamReceiveWindow
|
||||
if maxStreamReceiveWindow == 0 {
|
||||
maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow
|
||||
if initialConnectionReceiveWindow == 0 {
|
||||
initialConnectionReceiveWindow = protocol.DefaultInitialMaxData
|
||||
}
|
||||
maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow
|
||||
if maxConnectionReceiveWindow == 0 {
|
||||
maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
maxIncomingStreams := config.MaxIncomingStreams
|
||||
if maxIncomingStreams == 0 {
|
||||
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
||||
} else if maxIncomingStreams < 0 {
|
||||
maxIncomingStreams = 0
|
||||
}
|
||||
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
||||
if maxIncomingUniStreams == 0 {
|
||||
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
connIDGenerator := config.ConnectionIDGenerator
|
||||
if connIDGenerator == nil {
|
||||
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
MaxTokenAge: config.MaxTokenAge,
|
||||
MaxRetryTokenAge: config.MaxRetryTokenAge,
|
||||
RequireAddressValidation: config.RequireAddressValidation,
|
||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||
InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
|
||||
MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
|
||||
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
ConnectionIDLength: conIDLen,
|
||||
ConnectionIDGenerator: connIDGenerator,
|
||||
StatelessResetKey: config.StatelessResetKey,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
|
||||
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
|
||||
Allow0RTT: config.Allow0RTT,
|
||||
Tracer: config.Tracer,
|
||||
}
|
||||
}
|
139
vendor/github.com/quic-go/quic-go/conn_id_generator.go
generated
vendored
Normal file
139
vendor/github.com/quic-go/quic-go/conn_id_generator.go
generated
vendored
Normal file
|
@ -0,0 +1,139 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type connIDGenerator struct {
|
||||
generator ConnectionIDGenerator
|
||||
highestSeq uint64
|
||||
|
||||
activeSrcConnIDs map[uint64]protocol.ConnectionID
|
||||
initialClientDestConnID *protocol.ConnectionID // nil for the client
|
||||
|
||||
addConnectionID func(protocol.ConnectionID)
|
||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
removeConnectionID func(protocol.ConnectionID)
|
||||
retireConnectionID func(protocol.ConnectionID)
|
||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
func newConnIDGenerator(
|
||||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
addConnectionID func(protocol.ConnectionID),
|
||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||
removeConnectionID func(protocol.ConnectionID),
|
||||
retireConnectionID func(protocol.ConnectionID),
|
||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
|
||||
queueControlFrame func(wire.Frame),
|
||||
generator ConnectionIDGenerator,
|
||||
) *connIDGenerator {
|
||||
m := &connIDGenerator{
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
addConnectionID: addConnectionID,
|
||||
getStatelessResetToken: getStatelessResetToken,
|
||||
removeConnectionID: removeConnectionID,
|
||||
retireConnectionID: retireConnectionID,
|
||||
replaceWithClosed: replaceWithClosed,
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
m.initialClientDestConnID = initialClientDestConnID
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
|
||||
if m.generator.ConnectionIDLen() == 0 {
|
||||
return nil
|
||||
}
|
||||
// The active_connection_id_limit transport parameter is the number of
|
||||
// connection IDs the peer will store. This limit includes the connection ID
|
||||
// used during the handshake, and the one sent in the preferred_address
|
||||
// transport parameter.
|
||||
// We currently don't send the preferred_address transport parameter,
|
||||
// so we can issue (limit - 1) connection IDs.
|
||||
for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ {
|
||||
if err := m.issueNewConnID(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
|
||||
if seq > m.highestSeq {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
|
||||
}
|
||||
}
|
||||
connID, ok := m.activeSrcConnIDs[seq]
|
||||
// We might already have deleted this connection ID, if this is a duplicate frame.
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if connID == sentWithDestConnID {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
|
||||
}
|
||||
}
|
||||
m.retireConnectionID(connID)
|
||||
delete(m.activeSrcConnIDs, seq)
|
||||
// Don't issue a replacement for the initial connection ID.
|
||||
if seq == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.issueNewConnID()
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) issueNewConnID() error {
|
||||
connID, err := m.generator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.activeSrcConnIDs[m.highestSeq+1] = connID
|
||||
m.addConnectionID(connID)
|
||||
m.queueControlFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: m.highestSeq + 1,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: m.getStatelessResetToken(connID),
|
||||
})
|
||||
m.highestSeq++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) SetHandshakeComplete() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.retireConnectionID(*m.initialClientDestConnID)
|
||||
m.initialClientDestConnID = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveAll() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.removeConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
m.removeConnectionID(connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
|
||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||
if m.initialClientDestConnID != nil {
|
||||
connIDs = append(connIDs, *m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
m.replaceWithClosed(connIDs, pers, connClose)
|
||||
}
|
214
vendor/github.com/quic-go/quic-go/conn_id_manager.go
generated
vendored
Normal file
214
vendor/github.com/quic-go/quic-go/conn_id_manager.go
generated
vendored
Normal file
|
@ -0,0 +1,214 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type newConnID struct {
|
||||
SequenceNumber uint64
|
||||
ConnectionID protocol.ConnectionID
|
||||
StatelessResetToken protocol.StatelessResetToken
|
||||
}
|
||||
|
||||
type connIDManager struct {
|
||||
queue list.List[newConnID]
|
||||
|
||||
handshakeComplete bool
|
||||
activeSequenceNumber uint64
|
||||
highestRetired uint64
|
||||
activeConnectionID protocol.ConnectionID
|
||||
activeStatelessResetToken *protocol.StatelessResetToken
|
||||
|
||||
// We change the connection ID after sending on average
|
||||
// protocol.PacketsPerConnectionID packets. The actual value is randomized
|
||||
// hide the packet loss rate from on-path observers.
|
||||
rand utils.Rand
|
||||
packetsSinceLastChange uint32
|
||||
packetsPerConnectionID uint32
|
||||
|
||||
addStatelessResetToken func(protocol.StatelessResetToken)
|
||||
removeStatelessResetToken func(protocol.StatelessResetToken)
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
func newConnIDManager(
|
||||
initialDestConnID protocol.ConnectionID,
|
||||
addStatelessResetToken func(protocol.StatelessResetToken),
|
||||
removeStatelessResetToken func(protocol.StatelessResetToken),
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *connIDManager {
|
||||
return &connIDManager{
|
||||
activeConnectionID: initialDestConnID,
|
||||
addStatelessResetToken: addStatelessResetToken,
|
||||
removeStatelessResetToken: removeStatelessResetToken,
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
|
||||
return h.addConnectionID(1, connID, resetToken)
|
||||
}
|
||||
|
||||
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
|
||||
if err := h.add(f); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
|
||||
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
|
||||
// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
|
||||
// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
|
||||
if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: f.SequenceNumber,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retire elements in the queue.
|
||||
// Doesn't retire the active connection ID.
|
||||
if f.RetirePriorTo > h.highestRetired {
|
||||
var next *list.Element[newConnID]
|
||||
for el := h.queue.Front(); el != nil; el = next {
|
||||
if el.Value.SequenceNumber >= f.RetirePriorTo {
|
||||
break
|
||||
}
|
||||
next = el.Next()
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: el.Value.SequenceNumber,
|
||||
})
|
||||
h.queue.Remove(el)
|
||||
}
|
||||
h.highestRetired = f.RetirePriorTo
|
||||
}
|
||||
|
||||
if f.SequenceNumber == h.activeSequenceNumber {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Retire the active connection ID, if necessary.
|
||||
if h.activeSequenceNumber < f.RetirePriorTo {
|
||||
// The queue is guaranteed to have at least one element at this point.
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
|
||||
// insert a new element at the end
|
||||
if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq {
|
||||
h.queue.PushBack(newConnID{
|
||||
SequenceNumber: seq,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: resetToken,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
// insert a new element somewhere in the middle
|
||||
for el := h.queue.Front(); el != nil; el = el.Next() {
|
||||
if el.Value.SequenceNumber == seq {
|
||||
if el.Value.ConnectionID != connID {
|
||||
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
|
||||
}
|
||||
if el.Value.StatelessResetToken != resetToken {
|
||||
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
|
||||
}
|
||||
break
|
||||
}
|
||||
if el.Value.SequenceNumber > seq {
|
||||
h.queue.InsertBefore(newConnID{
|
||||
SequenceNumber: seq,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: resetToken,
|
||||
}, el)
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) updateConnectionID() {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: h.activeSequenceNumber,
|
||||
})
|
||||
h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber)
|
||||
if h.activeStatelessResetToken != nil {
|
||||
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
|
||||
front := h.queue.Remove(h.queue.Front())
|
||||
h.activeSequenceNumber = front.SequenceNumber
|
||||
h.activeConnectionID = front.ConnectionID
|
||||
h.activeStatelessResetToken = &front.StatelessResetToken
|
||||
h.packetsSinceLastChange = 0
|
||||
h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
|
||||
h.addStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
|
||||
func (h *connIDManager) Close() {
|
||||
if h.activeStatelessResetToken != nil {
|
||||
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
}
|
||||
|
||||
// is called when the server performs a Retry
|
||||
// and when the server changes the connection ID in the first Initial sent
|
||||
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
|
||||
if h.activeSequenceNumber != 0 {
|
||||
panic("expected first connection ID to have sequence number 0")
|
||||
}
|
||||
h.activeConnectionID = newConnID
|
||||
}
|
||||
|
||||
// is called when the server provides a stateless reset token in the transport parameters
|
||||
func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
|
||||
if h.activeSequenceNumber != 0 {
|
||||
panic("expected first connection ID to have sequence number 0")
|
||||
}
|
||||
h.activeStatelessResetToken = &token
|
||||
h.addStatelessResetToken(token)
|
||||
}
|
||||
|
||||
func (h *connIDManager) SentPacket() {
|
||||
h.packetsSinceLastChange++
|
||||
}
|
||||
|
||||
func (h *connIDManager) shouldUpdateConnID() bool {
|
||||
if !h.handshakeComplete {
|
||||
return false
|
||||
}
|
||||
// initiate the first change as early as possible (after handshake completion)
|
||||
if h.queue.Len() > 0 && h.activeSequenceNumber == 0 {
|
||||
return true
|
||||
}
|
||||
// For later changes, only change if
|
||||
// 1. The queue of connection IDs is filled more than 50%.
|
||||
// 2. We sent at least PacketsPerConnectionID packets
|
||||
return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs &&
|
||||
h.packetsSinceLastChange >= h.packetsPerConnectionID
|
||||
}
|
||||
|
||||
func (h *connIDManager) Get() protocol.ConnectionID {
|
||||
if h.shouldUpdateConnID() {
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return h.activeConnectionID
|
||||
}
|
||||
|
||||
func (h *connIDManager) SetHandshakeComplete() {
|
||||
h.handshakeComplete = true
|
||||
}
|
2185
vendor/github.com/quic-go/quic-go/connection.go
generated
vendored
Normal file
2185
vendor/github.com/quic-go/quic-go/connection.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load diff
51
vendor/github.com/quic-go/quic-go/connection_timer.go
generated
vendored
Normal file
51
vendor/github.com/quic-go/quic-go/connection_timer.go
generated
vendored
Normal file
|
@ -0,0 +1,51 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
var deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine
|
||||
|
||||
type connectionTimer struct {
|
||||
timer *utils.Timer
|
||||
last time.Time
|
||||
}
|
||||
|
||||
func newTimer() *connectionTimer {
|
||||
return &connectionTimer{timer: utils.NewTimer()}
|
||||
}
|
||||
|
||||
func (t *connectionTimer) SetRead() {
|
||||
if deadline := t.timer.Deadline(); deadline != deadlineSendImmediately {
|
||||
t.last = deadline
|
||||
}
|
||||
t.timer.SetRead()
|
||||
}
|
||||
|
||||
func (t *connectionTimer) Chan() <-chan time.Time {
|
||||
return t.timer.Chan()
|
||||
}
|
||||
|
||||
// SetTimer resets the timer.
|
||||
// It makes sure that the deadline is strictly increasing.
|
||||
// This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet.
|
||||
// This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately.
|
||||
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) {
|
||||
deadline := idleTimeoutOrKeepAlive
|
||||
if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) {
|
||||
deadline = ackAlarm
|
||||
}
|
||||
if !lossTime.IsZero() && lossTime.Before(deadline) && lossTime.After(t.last) {
|
||||
deadline = lossTime
|
||||
}
|
||||
if !pacing.IsZero() && pacing.Before(deadline) {
|
||||
deadline = pacing
|
||||
}
|
||||
t.timer.Reset(deadline)
|
||||
}
|
||||
|
||||
func (t *connectionTimer) Stop() {
|
||||
t.timer.Stop()
|
||||
}
|
115
vendor/github.com/quic-go/quic-go/crypto_stream.go
generated
vendored
Normal file
115
vendor/github.com/quic-go/quic-go/crypto_stream.go
generated
vendored
Normal file
|
@ -0,0 +1,115 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStream interface {
|
||||
// for receiving data
|
||||
HandleCryptoFrame(*wire.CryptoFrame) error
|
||||
GetCryptoData() []byte
|
||||
Finish() error
|
||||
// for sending data
|
||||
io.Writer
|
||||
HasData() bool
|
||||
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
|
||||
}
|
||||
|
||||
type cryptoStreamImpl struct {
|
||||
queue *frameSorter
|
||||
msgBuf []byte
|
||||
|
||||
highestOffset protocol.ByteCount
|
||||
finished bool
|
||||
|
||||
writeOffset protocol.ByteCount
|
||||
writeBuf []byte
|
||||
}
|
||||
|
||||
func newCryptoStream() cryptoStream {
|
||||
return &cryptoStreamImpl{queue: newFrameSorter()}
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
|
||||
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.CryptoBufferExceeded,
|
||||
ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset),
|
||||
}
|
||||
}
|
||||
if s.finished {
|
||||
if highestOffset > s.highestOffset {
|
||||
// reject crypto data received after this stream was already finished
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "received crypto data after change of encryption level",
|
||||
}
|
||||
}
|
||||
// ignore data with a smaller offset than the highest received
|
||||
// could e.g. be a retransmission
|
||||
return nil
|
||||
}
|
||||
s.highestOffset = utils.Max(s.highestOffset, highestOffset)
|
||||
if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
_, data, _ := s.queue.Pop()
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
s.msgBuf = append(s.msgBuf, data...)
|
||||
}
|
||||
}
|
||||
|
||||
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) GetCryptoData() []byte {
|
||||
if len(s.msgBuf) < 4 {
|
||||
return nil
|
||||
}
|
||||
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
|
||||
if len(s.msgBuf) < msgLen {
|
||||
return nil
|
||||
}
|
||||
msg := make([]byte, msgLen)
|
||||
copy(msg, s.msgBuf[:msgLen])
|
||||
s.msgBuf = s.msgBuf[msgLen:]
|
||||
return msg
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) Finish() error {
|
||||
if s.queue.HasMoreData() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "encryption level changed, but crypto stream has more data to read",
|
||||
}
|
||||
}
|
||||
s.finished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Writes writes data that should be sent out in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
|
||||
s.writeBuf = append(s.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) HasData() bool {
|
||||
return len(s.writeBuf) > 0
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||
n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
||||
f.Data = s.writeBuf[:n]
|
||||
s.writeBuf = s.writeBuf[n:]
|
||||
s.writeOffset += n
|
||||
return f
|
||||
}
|
61
vendor/github.com/quic-go/quic-go/crypto_stream_manager.go
generated
vendored
Normal file
61
vendor/github.com/quic-go/quic-go/crypto_stream_manager.go
generated
vendored
Normal file
|
@ -0,0 +1,61 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoDataHandler interface {
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
}
|
||||
|
||||
type cryptoStreamManager struct {
|
||||
cryptoHandler cryptoDataHandler
|
||||
|
||||
initialStream cryptoStream
|
||||
handshakeStream cryptoStream
|
||||
oneRTTStream cryptoStream
|
||||
}
|
||||
|
||||
func newCryptoStreamManager(
|
||||
cryptoHandler cryptoDataHandler,
|
||||
initialStream cryptoStream,
|
||||
handshakeStream cryptoStream,
|
||||
oneRTTStream cryptoStream,
|
||||
) *cryptoStreamManager {
|
||||
return &cryptoStreamManager{
|
||||
cryptoHandler: cryptoHandler,
|
||||
initialStream: initialStream,
|
||||
handshakeStream: handshakeStream,
|
||||
oneRTTStream: oneRTTStream,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
|
||||
var str cryptoStream
|
||||
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
str = m.initialStream
|
||||
case protocol.EncryptionHandshake:
|
||||
str = m.handshakeStream
|
||||
case protocol.Encryption1RTT:
|
||||
str = m.oneRTTStream
|
||||
default:
|
||||
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
if err := str.HandleCryptoFrame(frame); err != nil {
|
||||
return false, err
|
||||
}
|
||||
for {
|
||||
data := str.GetCryptoData()
|
||||
if data == nil {
|
||||
return false, nil
|
||||
}
|
||||
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
|
||||
return true, str.Finish()
|
||||
}
|
||||
}
|
||||
}
|
99
vendor/github.com/quic-go/quic-go/datagram_queue.go
generated
vendored
Normal file
99
vendor/github.com/quic-go/quic-go/datagram_queue.go
generated
vendored
Normal file
|
@ -0,0 +1,99 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type datagramQueue struct {
|
||||
sendQueue chan *wire.DatagramFrame
|
||||
nextFrame *wire.DatagramFrame
|
||||
rcvQueue chan []byte
|
||||
|
||||
closeErr error
|
||||
closed chan struct{}
|
||||
|
||||
hasData func()
|
||||
|
||||
dequeued chan struct{}
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
|
||||
return &datagramQueue{
|
||||
hasData: hasData,
|
||||
sendQueue: make(chan *wire.DatagramFrame, 1),
|
||||
rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen),
|
||||
dequeued: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// AddAndWait queues a new DATAGRAM frame for sending.
|
||||
// It blocks until the frame has been dequeued.
|
||||
func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error {
|
||||
select {
|
||||
case h.sendQueue <- f:
|
||||
h.hasData()
|
||||
case <-h.closed:
|
||||
return h.closeErr
|
||||
}
|
||||
|
||||
select {
|
||||
case <-h.dequeued:
|
||||
return nil
|
||||
case <-h.closed:
|
||||
return h.closeErr
|
||||
}
|
||||
}
|
||||
|
||||
// Peek gets the next DATAGRAM frame for sending.
|
||||
// If actually sent out, Pop needs to be called before the next call to Peek.
|
||||
func (h *datagramQueue) Peek() *wire.DatagramFrame {
|
||||
if h.nextFrame != nil {
|
||||
return h.nextFrame
|
||||
}
|
||||
select {
|
||||
case h.nextFrame = <-h.sendQueue:
|
||||
h.dequeued <- struct{}{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return h.nextFrame
|
||||
}
|
||||
|
||||
func (h *datagramQueue) Pop() {
|
||||
if h.nextFrame == nil {
|
||||
panic("datagramQueue BUG: Pop called for nil frame")
|
||||
}
|
||||
h.nextFrame = nil
|
||||
}
|
||||
|
||||
// HandleDatagramFrame handles a received DATAGRAM frame.
|
||||
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
|
||||
data := make([]byte, len(f.Data))
|
||||
copy(data, f.Data)
|
||||
select {
|
||||
case h.rcvQueue <- data:
|
||||
default:
|
||||
h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data))
|
||||
}
|
||||
}
|
||||
|
||||
// Receive gets a received DATAGRAM frame.
|
||||
func (h *datagramQueue) Receive() ([]byte, error) {
|
||||
select {
|
||||
case data := <-h.rcvQueue:
|
||||
return data, nil
|
||||
case <-h.closed:
|
||||
return nil, h.closeErr
|
||||
}
|
||||
}
|
||||
|
||||
func (h *datagramQueue) CloseWithError(e error) {
|
||||
h.closeErr = e
|
||||
close(h.closed)
|
||||
}
|
63
vendor/github.com/quic-go/quic-go/errors.go
generated
vendored
Normal file
63
vendor/github.com/quic-go/quic-go/errors.go
generated
vendored
Normal file
|
@ -0,0 +1,63 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
)
|
||||
|
||||
type (
|
||||
TransportError = qerr.TransportError
|
||||
ApplicationError = qerr.ApplicationError
|
||||
VersionNegotiationError = qerr.VersionNegotiationError
|
||||
StatelessResetError = qerr.StatelessResetError
|
||||
IdleTimeoutError = qerr.IdleTimeoutError
|
||||
HandshakeTimeoutError = qerr.HandshakeTimeoutError
|
||||
)
|
||||
|
||||
type (
|
||||
TransportErrorCode = qerr.TransportErrorCode
|
||||
ApplicationErrorCode = qerr.ApplicationErrorCode
|
||||
StreamErrorCode = qerr.StreamErrorCode
|
||||
)
|
||||
|
||||
const (
|
||||
NoError = qerr.NoError
|
||||
InternalError = qerr.InternalError
|
||||
ConnectionRefused = qerr.ConnectionRefused
|
||||
FlowControlError = qerr.FlowControlError
|
||||
StreamLimitError = qerr.StreamLimitError
|
||||
StreamStateError = qerr.StreamStateError
|
||||
FinalSizeError = qerr.FinalSizeError
|
||||
FrameEncodingError = qerr.FrameEncodingError
|
||||
TransportParameterError = qerr.TransportParameterError
|
||||
ConnectionIDLimitError = qerr.ConnectionIDLimitError
|
||||
ProtocolViolation = qerr.ProtocolViolation
|
||||
InvalidToken = qerr.InvalidToken
|
||||
ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode
|
||||
CryptoBufferExceeded = qerr.CryptoBufferExceeded
|
||||
KeyUpdateError = qerr.KeyUpdateError
|
||||
AEADLimitReached = qerr.AEADLimitReached
|
||||
NoViablePathError = qerr.NoViablePathError
|
||||
)
|
||||
|
||||
// A StreamError is used for Stream.CancelRead and Stream.CancelWrite.
|
||||
// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing.
|
||||
type StreamError struct {
|
||||
StreamID StreamID
|
||||
ErrorCode StreamErrorCode
|
||||
Remote bool
|
||||
}
|
||||
|
||||
func (e *StreamError) Is(target error) bool {
|
||||
_, ok := target.(*StreamError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *StreamError) Error() string {
|
||||
pers := "local"
|
||||
if e.Remote {
|
||||
pers = "remote"
|
||||
}
|
||||
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
|
||||
}
|
237
vendor/github.com/quic-go/quic-go/frame_sorter.go
generated
vendored
Normal file
237
vendor/github.com/quic-go/quic-go/frame_sorter.go
generated
vendored
Normal file
|
@ -0,0 +1,237 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
)
|
||||
|
||||
// byteInterval is an interval from one ByteCount to the other
|
||||
type byteInterval struct {
|
||||
Start protocol.ByteCount
|
||||
End protocol.ByteCount
|
||||
}
|
||||
|
||||
var byteIntervalElementPool sync.Pool
|
||||
|
||||
func init() {
|
||||
byteIntervalElementPool = *list.NewPool[byteInterval]()
|
||||
}
|
||||
|
||||
type frameSorterEntry struct {
|
||||
Data []byte
|
||||
DoneCb func()
|
||||
}
|
||||
|
||||
type frameSorter struct {
|
||||
queue map[protocol.ByteCount]frameSorterEntry
|
||||
readPos protocol.ByteCount
|
||||
gaps *list.List[byteInterval]
|
||||
}
|
||||
|
||||
var errDuplicateStreamData = errors.New("duplicate stream data")
|
||||
|
||||
func newFrameSorter() *frameSorter {
|
||||
s := frameSorter{
|
||||
gaps: list.NewWithPool[byteInterval](&byteIntervalElementPool),
|
||||
queue: make(map[protocol.ByteCount]frameSorterEntry),
|
||||
}
|
||||
s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount})
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error {
|
||||
err := s.push(data, offset, doneCb)
|
||||
if err == errDuplicateStreamData {
|
||||
if doneCb != nil {
|
||||
doneCb()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error {
|
||||
if len(data) == 0 {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
start := offset
|
||||
end := offset + protocol.ByteCount(len(data))
|
||||
|
||||
if end <= s.gaps.Front().Value.Start {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
startGap, startsInGap := s.findStartGap(start)
|
||||
endGap, endsInGap := s.findEndGap(startGap, end)
|
||||
|
||||
startGapEqualsEndGap := startGap == endGap
|
||||
|
||||
if (startGapEqualsEndGap && end <= startGap.Value.Start) ||
|
||||
(!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
startGapNext := startGap.Next()
|
||||
startGapEnd := startGap.Value.End // save it, in case startGap is modified
|
||||
endGapStart := endGap.Value.Start // save it, in case endGap is modified
|
||||
endGapEnd := endGap.Value.End // save it, in case endGap is modified
|
||||
var adjustedStartGapEnd bool
|
||||
var wasCut bool
|
||||
|
||||
pos := start
|
||||
var hasReplacedAtLeastOne bool
|
||||
for {
|
||||
oldEntry, ok := s.queue[pos]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
|
||||
if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) {
|
||||
// The existing frame is shorter than the new frame. Replace it.
|
||||
delete(s.queue, pos)
|
||||
pos += oldEntryLen
|
||||
hasReplacedAtLeastOne = true
|
||||
if oldEntry.DoneCb != nil {
|
||||
oldEntry.DoneCb()
|
||||
}
|
||||
} else {
|
||||
if !hasReplacedAtLeastOne {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
// The existing frame is longer than the new frame.
|
||||
// Cut the new frame such that the end aligns with the start of the existing frame.
|
||||
data = data[:pos-start]
|
||||
end = pos
|
||||
wasCut = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !startsInGap && !hasReplacedAtLeastOne {
|
||||
// cut the frame, such that it starts at the start of the gap
|
||||
data = data[startGap.Value.Start-start:]
|
||||
start = startGap.Value.Start
|
||||
wasCut = true
|
||||
}
|
||||
if start <= startGap.Value.Start {
|
||||
if end >= startGap.Value.End {
|
||||
// The frame covers the whole startGap. Delete the gap.
|
||||
s.gaps.Remove(startGap)
|
||||
} else {
|
||||
startGap.Value.Start = end
|
||||
}
|
||||
} else if !hasReplacedAtLeastOne {
|
||||
startGap.Value.End = start
|
||||
adjustedStartGapEnd = true
|
||||
}
|
||||
|
||||
if !startGapEqualsEndGap {
|
||||
s.deleteConsecutive(startGapEnd)
|
||||
var nextGap *list.Element[byteInterval]
|
||||
for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap {
|
||||
nextGap = gap.Next()
|
||||
s.deleteConsecutive(gap.Value.End)
|
||||
s.gaps.Remove(gap)
|
||||
}
|
||||
}
|
||||
|
||||
if !endsInGap && start != endGapEnd && end > endGapEnd {
|
||||
// cut the frame, such that it ends at the end of the gap
|
||||
data = data[:endGapEnd-start]
|
||||
end = endGapEnd
|
||||
wasCut = true
|
||||
}
|
||||
if end == endGapEnd {
|
||||
if !startGapEqualsEndGap {
|
||||
// The frame covers the whole endGap. Delete the gap.
|
||||
s.gaps.Remove(endGap)
|
||||
}
|
||||
} else {
|
||||
if startGapEqualsEndGap && adjustedStartGapEnd {
|
||||
// The frame split the existing gap into two.
|
||||
s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap)
|
||||
} else if !startGapEqualsEndGap {
|
||||
endGap.Value.Start = end
|
||||
}
|
||||
}
|
||||
|
||||
if wasCut && len(data) < protocol.MinStreamFrameBufferSize {
|
||||
newData := make([]byte, len(data))
|
||||
copy(newData, data)
|
||||
data = newData
|
||||
if doneCb != nil {
|
||||
doneCb()
|
||||
doneCb = nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
|
||||
return errors.New("too many gaps in received data")
|
||||
}
|
||||
|
||||
s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
|
||||
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
|
||||
if offset >= gap.Value.Start && offset <= gap.Value.End {
|
||||
return gap, true
|
||||
}
|
||||
if offset < gap.Value.Start {
|
||||
return gap, false
|
||||
}
|
||||
}
|
||||
panic("no gap found")
|
||||
}
|
||||
|
||||
func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
|
||||
for gap := startGap; gap != nil; gap = gap.Next() {
|
||||
if offset >= gap.Value.Start && offset < gap.Value.End {
|
||||
return gap, true
|
||||
}
|
||||
if offset < gap.Value.Start {
|
||||
return gap.Prev(), false
|
||||
}
|
||||
}
|
||||
panic("no gap found")
|
||||
}
|
||||
|
||||
// deleteConsecutive deletes consecutive frames from the queue, starting at pos
|
||||
func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) {
|
||||
for {
|
||||
oldEntry, ok := s.queue[pos]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
|
||||
delete(s.queue, pos)
|
||||
if oldEntry.DoneCb != nil {
|
||||
oldEntry.DoneCb()
|
||||
}
|
||||
pos += oldEntryLen
|
||||
}
|
||||
}
|
||||
|
||||
func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) {
|
||||
entry, ok := s.queue[s.readPos]
|
||||
if !ok {
|
||||
return s.readPos, nil, nil
|
||||
}
|
||||
delete(s.queue, s.readPos)
|
||||
offset := s.readPos
|
||||
s.readPos += protocol.ByteCount(len(entry.Data))
|
||||
if s.gaps.Front().Value.End <= s.readPos {
|
||||
panic("frame sorter BUG: read position higher than a gap")
|
||||
}
|
||||
return offset, entry.Data, entry.DoneCb
|
||||
}
|
||||
|
||||
// HasMoreData says if there is any more data queued at *any* offset.
|
||||
func (s *frameSorter) HasMoreData() bool {
|
||||
return len(s.queue) > 0
|
||||
}
|
168
vendor/github.com/quic-go/quic-go/framer.go
generated
vendored
Normal file
168
vendor/github.com/quic-go/quic-go/framer.go
generated
vendored
Normal file
|
@ -0,0 +1,168 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/ackhandler"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
type framer interface {
|
||||
HasData() bool
|
||||
|
||||
QueueControlFrame(wire.Frame)
|
||||
AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount)
|
||||
|
||||
AddActiveStream(protocol.StreamID)
|
||||
AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount)
|
||||
|
||||
Handle0RTTRejection() error
|
||||
}
|
||||
|
||||
type framerI struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
streamGetter streamGetter
|
||||
|
||||
activeStreams map[protocol.StreamID]struct{}
|
||||
streamQueue []protocol.StreamID
|
||||
|
||||
controlFrameMutex sync.Mutex
|
||||
controlFrames []wire.Frame
|
||||
}
|
||||
|
||||
var _ framer = &framerI{}
|
||||
|
||||
func newFramer(streamGetter streamGetter) framer {
|
||||
return &framerI{
|
||||
streamGetter: streamGetter,
|
||||
activeStreams: make(map[protocol.StreamID]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *framerI) HasData() bool {
|
||||
f.mutex.Lock()
|
||||
hasData := len(f.streamQueue) > 0
|
||||
f.mutex.Unlock()
|
||||
if hasData {
|
||||
return true
|
||||
}
|
||||
f.controlFrameMutex.Lock()
|
||||
hasData = len(f.controlFrames) > 0
|
||||
f.controlFrameMutex.Unlock()
|
||||
return hasData
|
||||
}
|
||||
|
||||
func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
||||
f.controlFrameMutex.Lock()
|
||||
f.controlFrames = append(f.controlFrames, frame)
|
||||
f.controlFrameMutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) {
|
||||
var length protocol.ByteCount
|
||||
f.controlFrameMutex.Lock()
|
||||
for len(f.controlFrames) > 0 {
|
||||
frame := f.controlFrames[len(f.controlFrames)-1]
|
||||
frameLen := frame.Length(v)
|
||||
if length+frameLen > maxLen {
|
||||
break
|
||||
}
|
||||
af := ackhandler.GetFrame()
|
||||
af.Frame = frame
|
||||
frames = append(frames, af)
|
||||
length += frameLen
|
||||
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
|
||||
}
|
||||
f.controlFrameMutex.Unlock()
|
||||
return frames, length
|
||||
}
|
||||
|
||||
func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
if _, ok := f.activeStreams[id]; !ok {
|
||||
f.streamQueue = append(f.streamQueue, id)
|
||||
f.activeStreams[id] = struct{}{}
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) {
|
||||
var length protocol.ByteCount
|
||||
var lastFrame *ackhandler.Frame
|
||||
f.mutex.Lock()
|
||||
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
|
||||
numActiveStreams := len(f.streamQueue)
|
||||
for i := 0; i < numActiveStreams; i++ {
|
||||
if protocol.MinStreamFrameSize+length > maxLen {
|
||||
break
|
||||
}
|
||||
id := f.streamQueue[0]
|
||||
f.streamQueue = f.streamQueue[1:]
|
||||
// This should never return an error. Better check it anyway.
|
||||
// The stream will only be in the streamQueue, if it enqueued itself there.
|
||||
str, err := f.streamGetter.GetOrOpenSendStream(id)
|
||||
// The stream can be nil if it completed after it said it had data.
|
||||
if str == nil || err != nil {
|
||||
delete(f.activeStreams, id)
|
||||
continue
|
||||
}
|
||||
remainingLen := maxLen - length
|
||||
// For the last STREAM frame, we'll remove the DataLen field later.
|
||||
// Therefore, we can pretend to have more bytes available when popping
|
||||
// the STREAM frame (which will always have the DataLen set).
|
||||
remainingLen += quicvarint.Len(uint64(remainingLen))
|
||||
frame, hasMoreData := str.popStreamFrame(remainingLen, v)
|
||||
if hasMoreData { // put the stream back in the queue (at the end)
|
||||
f.streamQueue = append(f.streamQueue, id)
|
||||
} else { // no more data to send. Stream is not active any more
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
// The frame can be nil
|
||||
// * if the receiveStream was canceled after it said it had data
|
||||
// * the remaining size doesn't allow us to add another STREAM frame
|
||||
if frame == nil {
|
||||
continue
|
||||
}
|
||||
frames = append(frames, frame)
|
||||
length += frame.Length(v)
|
||||
lastFrame = frame
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
if lastFrame != nil {
|
||||
lastFrameLen := lastFrame.Length(v)
|
||||
// account for the smaller size of the last STREAM frame
|
||||
lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false
|
||||
length += lastFrame.Length(v) - lastFrameLen
|
||||
}
|
||||
return frames, length
|
||||
}
|
||||
|
||||
func (f *framerI) Handle0RTTRejection() error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
f.controlFrameMutex.Lock()
|
||||
f.streamQueue = f.streamQueue[:0]
|
||||
for id := range f.activeStreams {
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
var j int
|
||||
for i, frame := range f.controlFrames {
|
||||
switch frame.(type) {
|
||||
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame:
|
||||
return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT")
|
||||
case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
|
||||
continue
|
||||
default:
|
||||
f.controlFrames[j] = f.controlFrames[i]
|
||||
j++
|
||||
}
|
||||
}
|
||||
f.controlFrames = f.controlFrames[:j]
|
||||
f.controlFrameMutex.Unlock()
|
||||
return nil
|
||||
}
|
135
vendor/github.com/quic-go/quic-go/http3/body.go
generated
vendored
Normal file
135
vendor/github.com/quic-go/quic-go/http3/body.go
generated
vendored
Normal file
|
@ -0,0 +1,135 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by:
|
||||
// * for the server: the http.Request.Body
|
||||
// * for the client: the http.Response.Body
|
||||
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
|
||||
// When a stream is taken over, it's the caller's responsibility to close the stream.
|
||||
type HTTPStreamer interface {
|
||||
HTTPStream() Stream
|
||||
}
|
||||
|
||||
type StreamCreator interface {
|
||||
// Context returns a context that is cancelled when the underlying connection is closed.
|
||||
Context() context.Context
|
||||
OpenStream() (quic.Stream, error)
|
||||
OpenStreamSync(context.Context) (quic.Stream, error)
|
||||
OpenUniStream() (quic.SendStream, error)
|
||||
OpenUniStreamSync(context.Context) (quic.SendStream, error)
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
ConnectionState() quic.ConnectionState
|
||||
}
|
||||
|
||||
var _ StreamCreator = quic.Connection(nil)
|
||||
|
||||
// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
|
||||
// It is used by WebTransport to create WebTransport streams after a session has been established.
|
||||
type Hijacker interface {
|
||||
StreamCreator() StreamCreator
|
||||
}
|
||||
|
||||
// The body of a http.Request or http.Response.
|
||||
type body struct {
|
||||
str quic.Stream
|
||||
|
||||
wasHijacked bool // set when HTTPStream is called
|
||||
}
|
||||
|
||||
var (
|
||||
_ io.ReadCloser = &body{}
|
||||
_ HTTPStreamer = &body{}
|
||||
)
|
||||
|
||||
func newRequestBody(str Stream) *body {
|
||||
return &body{str: str}
|
||||
}
|
||||
|
||||
func (r *body) HTTPStream() Stream {
|
||||
r.wasHijacked = true
|
||||
return r.str
|
||||
}
|
||||
|
||||
func (r *body) wasStreamHijacked() bool {
|
||||
return r.wasHijacked
|
||||
}
|
||||
|
||||
func (r *body) Read(b []byte) (int, error) {
|
||||
return r.str.Read(b)
|
||||
}
|
||||
|
||||
func (r *body) Close() error {
|
||||
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
|
||||
return nil
|
||||
}
|
||||
|
||||
type hijackableBody struct {
|
||||
body
|
||||
conn quic.Connection // only needed to implement Hijacker
|
||||
|
||||
// only set for the http.Response
|
||||
// The channel is closed when the user is done with this response:
|
||||
// either when Read() errors, or when Close() is called.
|
||||
reqDone chan<- struct{}
|
||||
reqDoneClosed bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ Hijacker = &hijackableBody{}
|
||||
_ HTTPStreamer = &hijackableBody{}
|
||||
)
|
||||
|
||||
func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody {
|
||||
return &hijackableBody{
|
||||
body: body{
|
||||
str: str,
|
||||
},
|
||||
reqDone: done,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *hijackableBody) StreamCreator() StreamCreator {
|
||||
return r.conn
|
||||
}
|
||||
|
||||
func (r *hijackableBody) Read(b []byte) (int, error) {
|
||||
n, err := r.str.Read(b)
|
||||
if err != nil {
|
||||
r.requestDone()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *hijackableBody) requestDone() {
|
||||
if r.reqDoneClosed || r.reqDone == nil {
|
||||
return
|
||||
}
|
||||
if r.reqDone != nil {
|
||||
close(r.reqDone)
|
||||
}
|
||||
r.reqDoneClosed = true
|
||||
}
|
||||
|
||||
func (r *body) StreamID() quic.StreamID {
|
||||
return r.str.StreamID()
|
||||
}
|
||||
|
||||
func (r *hijackableBody) Close() error {
|
||||
r.requestDone()
|
||||
// If the EOF was read, CancelRead() is a no-op.
|
||||
r.str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *hijackableBody) HTTPStream() Stream {
|
||||
return r.str
|
||||
}
|
55
vendor/github.com/quic-go/quic-go/http3/capsule.go
generated
vendored
Normal file
55
vendor/github.com/quic-go/quic-go/http3/capsule.go
generated
vendored
Normal file
|
@ -0,0 +1,55 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// CapsuleType is the type of the capsule.
|
||||
type CapsuleType uint64
|
||||
|
||||
type exactReader struct {
|
||||
R *io.LimitedReader
|
||||
}
|
||||
|
||||
func (r *exactReader) Read(b []byte) (int, error) {
|
||||
n, err := r.R.Read(b)
|
||||
if r.R.N > 0 {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ParseCapsule parses the header of a Capsule.
|
||||
// It returns an io.LimitedReader that can be used to read the Capsule value.
|
||||
// The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
|
||||
func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
|
||||
ct, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return 0, nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
l, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return 0, nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
return CapsuleType(ct), &exactReader{R: io.LimitReader(r, int64(l)).(*io.LimitedReader)}, nil
|
||||
}
|
||||
|
||||
// WriteCapsule writes a capsule
|
||||
func WriteCapsule(w quicvarint.Writer, ct CapsuleType, value []byte) error {
|
||||
b := make([]byte, 0, 16)
|
||||
b = quicvarint.Append(b, uint64(ct))
|
||||
b = quicvarint.Append(b, uint64(len(value)))
|
||||
if _, err := w.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write(value)
|
||||
return err
|
||||
}
|
450
vendor/github.com/quic-go/quic-go/http3/client.go
generated
vendored
Normal file
450
vendor/github.com/quic-go/quic-go/http3/client.go
generated
vendored
Normal file
|
@ -0,0 +1,450 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
// MethodGet0RTT allows a GET request to be sent using 0-RTT.
|
||||
// Note that 0-RTT data doesn't provide replay protection.
|
||||
const MethodGet0RTT = "GET_0RTT"
|
||||
|
||||
const (
|
||||
defaultUserAgent = "quic-go HTTP/3"
|
||||
defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
|
||||
)
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
|
||||
KeepAlivePeriod: 10 * time.Second,
|
||||
Versions: []protocol.VersionNumber{protocol.VersionTLS},
|
||||
}
|
||||
|
||||
type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
|
||||
|
||||
var dialAddr = quic.DialAddrEarlyContext
|
||||
|
||||
type roundTripperOpts struct {
|
||||
DisableCompression bool
|
||||
EnableDatagram bool
|
||||
MaxHeaderBytes int64
|
||||
AdditionalSettings map[uint64]uint64
|
||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||
}
|
||||
|
||||
// client is a HTTP3 client doing requests
|
||||
type client struct {
|
||||
tlsConf *tls.Config
|
||||
config *quic.Config
|
||||
opts *roundTripperOpts
|
||||
|
||||
dialOnce sync.Once
|
||||
dialer dialFunc
|
||||
handshakeErr error
|
||||
|
||||
requestWriter *requestWriter
|
||||
|
||||
decoder *qpack.Decoder
|
||||
|
||||
hostname string
|
||||
conn quic.EarlyConnection
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &client{}
|
||||
|
||||
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
||||
if conf == nil {
|
||||
conf = defaultQuicConfig.Clone()
|
||||
} else if len(conf.Versions) == 0 {
|
||||
conf = conf.Clone()
|
||||
conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
|
||||
}
|
||||
if len(conf.Versions) != 1 {
|
||||
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
||||
}
|
||||
if conf.MaxIncomingStreams == 0 {
|
||||
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||
}
|
||||
conf.EnableDatagrams = opts.EnableDatagram
|
||||
logger := utils.DefaultLogger.WithPrefix("h3 client")
|
||||
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = tlsConf.Clone()
|
||||
}
|
||||
// Replace existing ALPNs by H3
|
||||
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
|
||||
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
tlsConf: tlsConf,
|
||||
requestWriter: newRequestWriter(logger),
|
||||
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
|
||||
config: conf,
|
||||
opts: opts,
|
||||
dialer: dialer,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
var err error
|
||||
if c.dialer != nil {
|
||||
c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send the SETTINGs frame, using 0-RTT data, if possible
|
||||
go func() {
|
||||
if err := c.setupConn(); err != nil {
|
||||
c.logger.Debugf("Setting up connection failed: %s", err)
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
|
||||
}
|
||||
}()
|
||||
|
||||
if c.opts.StreamHijacker != nil {
|
||||
go c.handleBidirectionalStreams()
|
||||
}
|
||||
go c.handleUnidirectionalStreams()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) setupConn() error {
|
||||
// open the control stream
|
||||
str, err := c.conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b := make([]byte, 0, 64)
|
||||
b = quicvarint.Append(b, streamTypeControlStream)
|
||||
// send the SETTINGS frame
|
||||
b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b)
|
||||
_, err = str.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) handleBidirectionalStreams() {
|
||||
for {
|
||||
str, err := c.conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
|
||||
return
|
||||
}
|
||||
go func(str quic.Stream) {
|
||||
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
|
||||
return c.opts.StreamHijacker(ft, c.conn, str, e)
|
||||
})
|
||||
if err == errHijacked {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
c.logger.Debugf("error handling stream: %s", err)
|
||||
}
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleUnidirectionalStreams() {
|
||||
for {
|
||||
str, err := c.conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
c.logger.Debugf("accepting unidirectional stream failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func(str quic.ReceiveStream) {
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
|
||||
return
|
||||
}
|
||||
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||
return
|
||||
}
|
||||
// We're only interested in the control stream here.
|
||||
switch streamType {
|
||||
case streamTypeControlStream:
|
||||
case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
|
||||
// Our QPACK implementation doesn't use the dynamic table yet.
|
||||
// TODO: check that only one stream of each type is opened.
|
||||
return
|
||||
case streamTypePushStream:
|
||||
// We never increased the Push ID, so we don't expect any push streams.
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
|
||||
return
|
||||
default:
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
|
||||
return
|
||||
}
|
||||
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||
return
|
||||
}
|
||||
f, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
|
||||
return
|
||||
}
|
||||
if !sf.Datagram {
|
||||
return
|
||||
}
|
||||
// If datagram support was enabled on our side as well as on the server side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
|
||||
}
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
if c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
|
||||
}
|
||||
|
||||
func (c *client) maxHeaderBytes() uint64 {
|
||||
if c.opts.MaxHeaderBytes <= 0 {
|
||||
return defaultMaxResponseHeaderBytes
|
||||
}
|
||||
return uint64(c.opts.MaxHeaderBytes)
|
||||
}
|
||||
|
||||
// RoundTripOpt executes a request and returns a response
|
||||
func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||
}
|
||||
|
||||
c.dialOnce.Do(func() {
|
||||
c.handshakeErr = c.dial(req.Context())
|
||||
})
|
||||
|
||||
if c.handshakeErr != nil {
|
||||
return nil, c.handshakeErr
|
||||
}
|
||||
|
||||
// Immediately send out this request, if this is a 0-RTT request.
|
||||
if req.Method == MethodGet0RTT {
|
||||
req.Method = http.MethodGet
|
||||
} else {
|
||||
// wait for the handshake to complete
|
||||
select {
|
||||
case <-c.conn.HandshakeComplete().Done():
|
||||
case <-req.Context().Done():
|
||||
return nil, req.Context().Err()
|
||||
}
|
||||
}
|
||||
|
||||
str, err := c.conn.OpenStreamSync(req.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request Cancellation:
|
||||
// This go routine keeps running even after RoundTripOpt() returns.
|
||||
// It is shut down when the application is done processing the body.
|
||||
reqDone := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
||||
str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
|
||||
case <-reqDone:
|
||||
}
|
||||
}()
|
||||
|
||||
doneChan := reqDone
|
||||
if opt.DontCloseRequestStream {
|
||||
doneChan = nil
|
||||
}
|
||||
rsp, rerr := c.doRequest(req, str, opt, doneChan)
|
||||
if rerr.err != nil { // if any error occurred
|
||||
close(reqDone)
|
||||
<-done
|
||||
if rerr.streamErr != 0 { // if it was a stream error
|
||||
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
|
||||
}
|
||||
if rerr.connErr != 0 { // if it was a connection error
|
||||
var reason string
|
||||
if rerr.err != nil {
|
||||
reason = rerr.err.Error()
|
||||
}
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
}
|
||||
return nil, rerr.err
|
||||
}
|
||||
if opt.DontCloseRequestStream {
|
||||
close(reqDone)
|
||||
<-done
|
||||
}
|
||||
return rsp, rerr.err
|
||||
}
|
||||
|
||||
func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
|
||||
defer body.Close()
|
||||
b := make([]byte, bodyCopyBufferSize)
|
||||
for {
|
||||
n, rerr := body.Read(b)
|
||||
if n == 0 {
|
||||
if rerr == nil {
|
||||
continue
|
||||
}
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
if _, err := str.Write(b[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
if rerr != nil {
|
||||
if rerr == io.EOF {
|
||||
break
|
||||
}
|
||||
str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
||||
return rerr
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
||||
var requestGzip bool
|
||||
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||
requestGzip = true
|
||||
}
|
||||
if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
|
||||
return nil, newStreamError(errorInternalError, err)
|
||||
}
|
||||
|
||||
if req.Body == nil && !opt.DontCloseRequestStream {
|
||||
str.Close()
|
||||
}
|
||||
|
||||
hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
|
||||
if req.Body != nil {
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
if err := c.sendRequestBody(hstr, req.Body); err != nil {
|
||||
c.logger.Errorf("Error writing request: %s", err)
|
||||
}
|
||||
if !opt.DontCloseRequestStream {
|
||||
hstr.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
frame, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
return nil, newStreamError(errorFrameError, err)
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
|
||||
}
|
||||
if hf.Length > c.maxHeaderBytes() {
|
||||
return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||
return nil, newStreamError(errorRequestIncomplete, err)
|
||||
}
|
||||
hfs, err := c.decoder.DecodeFull(headerBlock)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
return nil, newConnError(errorGeneralProtocolError, err)
|
||||
}
|
||||
|
||||
connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS)
|
||||
res := &http.Response{
|
||||
Proto: "HTTP/3.0",
|
||||
ProtoMajor: 3,
|
||||
Header: http.Header{},
|
||||
TLS: &connState,
|
||||
Request: req,
|
||||
}
|
||||
for _, hf := range hfs {
|
||||
switch hf.Name {
|
||||
case ":status":
|
||||
status, err := strconv.Atoi(hf.Value)
|
||||
if err != nil {
|
||||
return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
|
||||
}
|
||||
res.StatusCode = status
|
||||
res.Status = hf.Value + " " + http.StatusText(status)
|
||||
default:
|
||||
res.Header.Add(hf.Name, hf.Value)
|
||||
}
|
||||
}
|
||||
respBody := newResponseBody(hstr, c.conn, reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
|
||||
isInformational := res.StatusCode >= 100 && res.StatusCode < 200
|
||||
isNoContent := res.StatusCode == http.StatusNoContent
|
||||
isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
|
||||
if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
|
||||
res.ContentLength = -1
|
||||
if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
|
||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
||||
res.ContentLength = clen64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||
res.Header.Del("Content-Encoding")
|
||||
res.Header.Del("Content-Length")
|
||||
res.ContentLength = -1
|
||||
res.Body = newGzipReader(respBody)
|
||||
res.Uncompressed = true
|
||||
} else {
|
||||
res.Body = respBody
|
||||
}
|
||||
|
||||
return res, requestError{}
|
||||
}
|
||||
|
||||
func (c *client) HandshakeComplete() bool {
|
||||
if c.conn == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-c.conn.HandshakeComplete().Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
73
vendor/github.com/quic-go/quic-go/http3/error_codes.go
generated
vendored
Normal file
73
vendor/github.com/quic-go/quic-go/http3/error_codes.go
generated
vendored
Normal file
|
@ -0,0 +1,73 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type errorCode quic.ApplicationErrorCode
|
||||
|
||||
const (
|
||||
errorNoError errorCode = 0x100
|
||||
errorGeneralProtocolError errorCode = 0x101
|
||||
errorInternalError errorCode = 0x102
|
||||
errorStreamCreationError errorCode = 0x103
|
||||
errorClosedCriticalStream errorCode = 0x104
|
||||
errorFrameUnexpected errorCode = 0x105
|
||||
errorFrameError errorCode = 0x106
|
||||
errorExcessiveLoad errorCode = 0x107
|
||||
errorIDError errorCode = 0x108
|
||||
errorSettingsError errorCode = 0x109
|
||||
errorMissingSettings errorCode = 0x10a
|
||||
errorRequestRejected errorCode = 0x10b
|
||||
errorRequestCanceled errorCode = 0x10c
|
||||
errorRequestIncomplete errorCode = 0x10d
|
||||
errorMessageError errorCode = 0x10e
|
||||
errorConnectError errorCode = 0x10f
|
||||
errorVersionFallback errorCode = 0x110
|
||||
errorDatagramError errorCode = 0x4a1268
|
||||
)
|
||||
|
||||
func (e errorCode) String() string {
|
||||
switch e {
|
||||
case errorNoError:
|
||||
return "H3_NO_ERROR"
|
||||
case errorGeneralProtocolError:
|
||||
return "H3_GENERAL_PROTOCOL_ERROR"
|
||||
case errorInternalError:
|
||||
return "H3_INTERNAL_ERROR"
|
||||
case errorStreamCreationError:
|
||||
return "H3_STREAM_CREATION_ERROR"
|
||||
case errorClosedCriticalStream:
|
||||
return "H3_CLOSED_CRITICAL_STREAM"
|
||||
case errorFrameUnexpected:
|
||||
return "H3_FRAME_UNEXPECTED"
|
||||
case errorFrameError:
|
||||
return "H3_FRAME_ERROR"
|
||||
case errorExcessiveLoad:
|
||||
return "H3_EXCESSIVE_LOAD"
|
||||
case errorIDError:
|
||||
return "H3_ID_ERROR"
|
||||
case errorSettingsError:
|
||||
return "H3_SETTINGS_ERROR"
|
||||
case errorMissingSettings:
|
||||
return "H3_MISSING_SETTINGS"
|
||||
case errorRequestRejected:
|
||||
return "H3_REQUEST_REJECTED"
|
||||
case errorRequestCanceled:
|
||||
return "H3_REQUEST_CANCELLED"
|
||||
case errorRequestIncomplete:
|
||||
return "H3_INCOMPLETE_REQUEST"
|
||||
case errorMessageError:
|
||||
return "H3_MESSAGE_ERROR"
|
||||
case errorConnectError:
|
||||
return "H3_CONNECT_ERROR"
|
||||
case errorVersionFallback:
|
||||
return "H3_VERSION_FALLBACK"
|
||||
case errorDatagramError:
|
||||
return "H3_DATAGRAM_ERROR"
|
||||
default:
|
||||
return fmt.Sprintf("unknown error code: %#x", uint16(e))
|
||||
}
|
||||
}
|
164
vendor/github.com/quic-go/quic-go/http3/frames.go
generated
vendored
Normal file
164
vendor/github.com/quic-go/quic-go/http3/frames.go
generated
vendored
Normal file
|
@ -0,0 +1,164 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// FrameType is the frame type of a HTTP/3 frame
|
||||
type FrameType uint64
|
||||
|
||||
type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error)
|
||||
|
||||
type frame interface{}
|
||||
|
||||
var errHijacked = errors.New("hijacked")
|
||||
|
||||
func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
|
||||
qr := quicvarint.NewReader(r)
|
||||
for {
|
||||
t, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
if unknownFrameHandler != nil {
|
||||
hijacked, err := unknownFrameHandler(0, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hijacked {
|
||||
return nil, errHijacked
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
|
||||
if t > 0xd && unknownFrameHandler != nil {
|
||||
hijacked, err := unknownFrameHandler(FrameType(t), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hijacked {
|
||||
return nil, errHijacked
|
||||
}
|
||||
// If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it.
|
||||
}
|
||||
l, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case 0x0:
|
||||
return &dataFrame{Length: l}, nil
|
||||
case 0x1:
|
||||
return &headersFrame{Length: l}, nil
|
||||
case 0x4:
|
||||
return parseSettingsFrame(r, l)
|
||||
case 0x3: // CANCEL_PUSH
|
||||
case 0x5: // PUSH_PROMISE
|
||||
case 0x7: // GOAWAY
|
||||
case 0xd: // MAX_PUSH_ID
|
||||
}
|
||||
// skip over unknown frames
|
||||
if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type dataFrame struct {
|
||||
Length uint64
|
||||
}
|
||||
|
||||
func (f *dataFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x0)
|
||||
return quicvarint.Append(b, f.Length)
|
||||
}
|
||||
|
||||
type headersFrame struct {
|
||||
Length uint64
|
||||
}
|
||||
|
||||
func (f *headersFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x1)
|
||||
return quicvarint.Append(b, f.Length)
|
||||
}
|
||||
|
||||
const settingDatagram = 0xffd277
|
||||
|
||||
type settingsFrame struct {
|
||||
Datagram bool
|
||||
Other map[uint64]uint64 // all settings that we don't explicitly recognize
|
||||
}
|
||||
|
||||
func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
|
||||
if l > 8*(1<<10) {
|
||||
return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l)
|
||||
}
|
||||
buf := make([]byte, l)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
frame := &settingsFrame{}
|
||||
b := bytes.NewReader(buf)
|
||||
var readDatagram bool
|
||||
for b.Len() > 0 {
|
||||
id, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return nil, err
|
||||
}
|
||||
val, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch id {
|
||||
case settingDatagram:
|
||||
if readDatagram {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
readDatagram = true
|
||||
if val != 0 && val != 1 {
|
||||
return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val)
|
||||
}
|
||||
frame.Datagram = val == 1
|
||||
default:
|
||||
if _, ok := frame.Other[id]; ok {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
if frame.Other == nil {
|
||||
frame.Other = make(map[uint64]uint64)
|
||||
}
|
||||
frame.Other[id] = val
|
||||
}
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *settingsFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x4)
|
||||
var l protocol.ByteCount
|
||||
for id, val := range f.Other {
|
||||
l += quicvarint.Len(id) + quicvarint.Len(val)
|
||||
}
|
||||
if f.Datagram {
|
||||
l += quicvarint.Len(settingDatagram) + quicvarint.Len(1)
|
||||
}
|
||||
b = quicvarint.Append(b, uint64(l))
|
||||
if f.Datagram {
|
||||
b = quicvarint.Append(b, settingDatagram)
|
||||
b = quicvarint.Append(b, 1)
|
||||
}
|
||||
for id, val := range f.Other {
|
||||
b = quicvarint.Append(b, id)
|
||||
b = quicvarint.Append(b, val)
|
||||
}
|
||||
return b
|
||||
}
|
39
vendor/github.com/quic-go/quic-go/http3/gzip_reader.go
generated
vendored
Normal file
39
vendor/github.com/quic-go/quic-go/http3/gzip_reader.go
generated
vendored
Normal file
|
@ -0,0 +1,39 @@
|
|||
package http3
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// gzipReader wraps a response body so it can lazily
|
||||
// call gzip.NewReader on the first call to Read
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
)
|
||||
|
||||
// call gzip.NewReader on the first call to Read
|
||||
type gzipReader struct {
|
||||
body io.ReadCloser // underlying Response.Body
|
||||
zr *gzip.Reader // lazily-initialized gzip reader
|
||||
zerr error // sticky error
|
||||
}
|
||||
|
||||
func newGzipReader(body io.ReadCloser) io.ReadCloser {
|
||||
return &gzipReader{body: body}
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Read(p []byte) (n int, err error) {
|
||||
if gz.zerr != nil {
|
||||
return 0, gz.zerr
|
||||
}
|
||||
if gz.zr == nil {
|
||||
gz.zr, err = gzip.NewReader(gz.body)
|
||||
if err != nil {
|
||||
gz.zerr = err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return gz.zr.Read(p)
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Close() error {
|
||||
return gz.body.Close()
|
||||
}
|
76
vendor/github.com/quic-go/quic-go/http3/http_stream.go
generated
vendored
Normal file
76
vendor/github.com/quic-go/quic-go/http3/http_stream.go
generated
vendored
Normal file
|
@ -0,0 +1,76 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// A Stream is a HTTP/3 stream.
|
||||
// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
|
||||
type Stream quic.Stream
|
||||
|
||||
// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
|
||||
// from the QUIC stream, it writes to and reads from the HTTP stream.
|
||||
type stream struct {
|
||||
quic.Stream
|
||||
|
||||
buf []byte
|
||||
|
||||
onFrameError func()
|
||||
bytesRemainingInFrame uint64
|
||||
}
|
||||
|
||||
var _ Stream = &stream{}
|
||||
|
||||
func newStream(str quic.Stream, onFrameError func()) *stream {
|
||||
return &stream{
|
||||
Stream: str,
|
||||
onFrameError: onFrameError,
|
||||
buf: make([]byte, 0, 16),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stream) Read(b []byte) (int, error) {
|
||||
if s.bytesRemainingInFrame == 0 {
|
||||
parseLoop:
|
||||
for {
|
||||
frame, err := parseNextFrame(s.Stream, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *headersFrame:
|
||||
// skip HEADERS frames
|
||||
continue
|
||||
case *dataFrame:
|
||||
s.bytesRemainingInFrame = f.Length
|
||||
break parseLoop
|
||||
default:
|
||||
s.onFrameError()
|
||||
// parseNextFrame skips over unknown frame types
|
||||
// Therefore, this condition is only entered when we parsed another known frame type.
|
||||
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var n int
|
||||
var err error
|
||||
if s.bytesRemainingInFrame < uint64(len(b)) {
|
||||
n, err = s.Stream.Read(b[:s.bytesRemainingInFrame])
|
||||
} else {
|
||||
n, err = s.Stream.Read(b)
|
||||
}
|
||||
s.bytesRemainingInFrame -= uint64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (int, error) {
|
||||
s.buf = s.buf[:0]
|
||||
s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf)
|
||||
if _, err := s.Stream.Write(s.buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.Stream.Write(b)
|
||||
}
|
111
vendor/github.com/quic-go/quic-go/http3/request.go
generated
vendored
Normal file
111
vendor/github.com/quic-go/quic-go/http3/request.go
generated
vendored
Normal file
|
@ -0,0 +1,111 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) {
|
||||
var path, authority, method, protocol, scheme, contentLengthStr string
|
||||
|
||||
httpHeaders := http.Header{}
|
||||
for _, h := range headers {
|
||||
switch h.Name {
|
||||
case ":path":
|
||||
path = h.Value
|
||||
case ":method":
|
||||
method = h.Value
|
||||
case ":authority":
|
||||
authority = h.Value
|
||||
case ":protocol":
|
||||
protocol = h.Value
|
||||
case ":scheme":
|
||||
scheme = h.Value
|
||||
case "content-length":
|
||||
contentLengthStr = h.Value
|
||||
default:
|
||||
if !h.IsPseudo() {
|
||||
httpHeaders.Add(h.Name, h.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
||||
if len(httpHeaders["Cookie"]) > 0 {
|
||||
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
|
||||
}
|
||||
|
||||
isConnect := method == http.MethodConnect
|
||||
// Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4
|
||||
isExtendedConnected := isConnect && protocol != ""
|
||||
if isExtendedConnected {
|
||||
if scheme == "" || path == "" || authority == "" {
|
||||
return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty")
|
||||
}
|
||||
} else if isConnect {
|
||||
if path != "" || authority == "" { // normal CONNECT
|
||||
return nil, errors.New(":path must be empty and :authority must not be empty")
|
||||
}
|
||||
} else if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
|
||||
return nil, errors.New(":path, :authority and :method must not be empty")
|
||||
}
|
||||
|
||||
var u *url.URL
|
||||
var requestURI string
|
||||
var err error
|
||||
|
||||
if isConnect {
|
||||
u = &url.URL{}
|
||||
if isExtendedConnected {
|
||||
u, err = url.ParseRequestURI(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
u.Path = path
|
||||
}
|
||||
u.Scheme = scheme
|
||||
u.Host = authority
|
||||
requestURI = authority
|
||||
} else {
|
||||
protocol = "HTTP/3.0"
|
||||
u, err = url.ParseRequestURI(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestURI = path
|
||||
}
|
||||
|
||||
var contentLength int64
|
||||
if len(contentLengthStr) > 0 {
|
||||
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Method: method,
|
||||
URL: u,
|
||||
Proto: protocol,
|
||||
ProtoMajor: 3,
|
||||
ProtoMinor: 0,
|
||||
Header: httpHeaders,
|
||||
Body: nil,
|
||||
ContentLength: contentLength,
|
||||
Host: authority,
|
||||
RequestURI: requestURI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func hostnameFromRequest(req *http.Request) string {
|
||||
if req.URL != nil {
|
||||
return req.URL.Host
|
||||
}
|
||||
return ""
|
||||
}
|
283
vendor/github.com/quic-go/quic-go/http3/request_writer.go
generated
vendored
Normal file
283
vendor/github.com/quic-go/quic-go/http3/request_writer.go
generated
vendored
Normal file
|
@ -0,0 +1,283 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const bodyCopyBufferSize = 8 * 1024
|
||||
|
||||
type requestWriter struct {
|
||||
mutex sync.Mutex
|
||||
encoder *qpack.Encoder
|
||||
headerBuf *bytes.Buffer
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newRequestWriter(logger utils.Logger) *requestWriter {
|
||||
headerBuf := &bytes.Buffer{}
|
||||
encoder := qpack.NewEncoder(headerBuf)
|
||||
return &requestWriter{
|
||||
encoder: encoder,
|
||||
headerBuf: headerBuf,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error {
|
||||
// TODO: figure out how to add support for trailers
|
||||
buf := &bytes.Buffer{}
|
||||
if err := w.writeHeaders(buf, req, gzip); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := str.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
defer w.encoder.Close()
|
||||
defer w.headerBuf.Reset()
|
||||
|
||||
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b := make([]byte, 0, 128)
|
||||
b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b)
|
||||
if _, err := wr.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := wr.Write(w.headerBuf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
// Modified to support Extended CONNECT:
|
||||
// Contrary to what the godoc for the http.Request says,
|
||||
// we do respect the Proto field if the method is CONNECT.
|
||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
host, err := httpguts.PunycodeHostPort(host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// http.NewRequest sets this field to HTTP/1.1
|
||||
isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"
|
||||
|
||||
var path string
|
||||
if req.Method != http.MethodConnect || isExtendedConnect {
|
||||
path = req.URL.RequestURI()
|
||||
if !validPseudoPath(path) {
|
||||
orig := path
|
||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||||
if !validPseudoPath(path) {
|
||||
if req.URL.Opaque != "" {
|
||||
return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||
} else {
|
||||
return fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for any invalid headers and return an error before we
|
||||
// potentially pollute our hpack state. (We want to be able to
|
||||
// continue to reuse the hpack encoder for future requests)
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return fmt.Errorf("invalid HTTP header name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enumerateHeaders := func(f func(name, value string)) {
|
||||
// 8.1.2.3 Request Pseudo-Header Fields
|
||||
// The :path pseudo-header field includes the path and query parts of the
|
||||
// target URI (the path-absolute production and optionally a '?' character
|
||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||
// [RFC3986]).
|
||||
f(":authority", host)
|
||||
f(":method", req.Method)
|
||||
if req.Method != http.MethodConnect || isExtendedConnect {
|
||||
f(":path", path)
|
||||
f(":scheme", req.URL.Scheme)
|
||||
}
|
||||
if isExtendedConnect {
|
||||
f(":protocol", req.Proto)
|
||||
}
|
||||
if trailers != "" {
|
||||
f("trailer", trailers)
|
||||
}
|
||||
|
||||
var didUA bool
|
||||
for k, vv := range req.Header {
|
||||
if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
|
||||
// Host is :authority, already sent.
|
||||
// Content-Length is automatic, set below.
|
||||
continue
|
||||
} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
|
||||
strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
|
||||
strings.EqualFold(k, "keep-alive") {
|
||||
// Per 8.1.2.2 Connection-Specific Header
|
||||
// Fields, don't send connection-specific
|
||||
// fields. We have already checked if any
|
||||
// are error-worthy so just ignore the rest.
|
||||
continue
|
||||
} else if strings.EqualFold(k, "user-agent") {
|
||||
// Match Go's http1 behavior: at most one
|
||||
// User-Agent. If set to nil or empty string,
|
||||
// then omit it. Otherwise if not mentioned,
|
||||
// include the default (below).
|
||||
didUA = true
|
||||
if len(vv) < 1 {
|
||||
continue
|
||||
}
|
||||
vv = vv[:1]
|
||||
if vv[0] == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, v := range vv {
|
||||
f(k, v)
|
||||
}
|
||||
}
|
||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
||||
f("content-length", strconv.FormatInt(contentLength, 10))
|
||||
}
|
||||
if addGzipHeader {
|
||||
f("accept-encoding", "gzip")
|
||||
}
|
||||
if !didUA {
|
||||
f("user-agent", defaultUserAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// Do a first pass over the headers counting bytes to ensure
|
||||
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
|
||||
// separate pass before encoding the headers to prevent
|
||||
// modifying the hpack state.
|
||||
hlSize := uint64(0)
|
||||
enumerateHeaders(func(name, value string) {
|
||||
hf := hpack.HeaderField{Name: name, Value: value}
|
||||
hlSize += uint64(hf.Size())
|
||||
})
|
||||
|
||||
// TODO: check maximum header list size
|
||||
// if hlSize > cc.peerMaxHeaderListSize {
|
||||
// return errRequestHeaderListSize
|
||||
// }
|
||||
|
||||
// trace := httptrace.ContextClientTrace(req.Context())
|
||||
// traceHeaders := traceHasWroteHeaderField(trace)
|
||||
|
||||
// Header list size is ok. Write the headers.
|
||||
enumerateHeaders(func(name, value string) {
|
||||
name = strings.ToLower(name)
|
||||
w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
|
||||
// if traceHeaders {
|
||||
// traceWroteHeaderField(trace, name, value)
|
||||
// }
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
// and returns a host:port. The port 443 is added if needed.
|
||||
func authorityAddr(scheme string, authority string) (addr string) {
|
||||
host, port, err := net.SplitHostPort(authority)
|
||||
if err != nil { // authority didn't have a port
|
||||
port = "443"
|
||||
if scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
host = authority
|
||||
}
|
||||
if a, err := idna.ToASCII(host); err == nil {
|
||||
host = a
|
||||
}
|
||||
// IPv6 address literal, without a port:
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host + ":" + port
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// validPseudoPath reports whether v is a valid :path pseudo-header
|
||||
// value. It must be either:
|
||||
//
|
||||
// *) a non-empty string starting with '/'
|
||||
// *) the string '*', for OPTIONS requests.
|
||||
//
|
||||
// For now this is only used a quick check for deciding when to clean
|
||||
// up Opaque URLs before sending requests from the Transport.
|
||||
// See golang.org/issue/16847
|
||||
//
|
||||
// We used to enforce that the path also didn't start with "//", but
|
||||
// Google's GFE accepts such paths and Chrome sends them, so ignore
|
||||
// that part of the spec. See golang.org/issue/19103.
|
||||
func validPseudoPath(v string) bool {
|
||||
return (len(v) > 0 && v[0] == '/') || v == "*"
|
||||
}
|
||||
|
||||
// actualContentLength returns a sanitized version of
|
||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||
// means unknown.
|
||||
func actualContentLength(req *http.Request) int64 {
|
||||
if req.Body == nil {
|
||||
return 0
|
||||
}
|
||||
if req.ContentLength != 0 {
|
||||
return req.ContentLength
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||
// transferWriter.shouldSendContentLength.
|
||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||
// -1 means unknown.
|
||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
||||
if contentLength > 0 {
|
||||
return true
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return false
|
||||
}
|
||||
// For zero bodies, whether we send a content-length depends on the method.
|
||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||
switch method {
|
||||
case "POST", "PUT", "PATCH":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
121
vendor/github.com/quic-go/quic-go/http3/response_writer.go
generated
vendored
Normal file
121
vendor/github.com/quic-go/quic-go/http3/response_writer.go
generated
vendored
Normal file
|
@ -0,0 +1,121 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
type responseWriter struct {
|
||||
conn quic.Connection
|
||||
bufferedStr *bufio.Writer
|
||||
buf []byte
|
||||
|
||||
header http.Header
|
||||
status int // status code passed to WriteHeader
|
||||
headerWritten bool
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.ResponseWriter = &responseWriter{}
|
||||
_ http.Flusher = &responseWriter{}
|
||||
_ Hijacker = &responseWriter{}
|
||||
)
|
||||
|
||||
func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
|
||||
return &responseWriter{
|
||||
header: http.Header{},
|
||||
buf: make([]byte, 16),
|
||||
conn: conn,
|
||||
bufferedStr: bufio.NewWriter(str),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(status int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
|
||||
if status < 100 || status >= 200 {
|
||||
w.headerWritten = true
|
||||
}
|
||||
w.status = status
|
||||
|
||||
var headers bytes.Buffer
|
||||
enc := qpack.NewEncoder(&headers)
|
||||
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
|
||||
|
||||
for k, v := range w.header {
|
||||
for index := range v {
|
||||
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
||||
}
|
||||
}
|
||||
|
||||
w.buf = w.buf[:0]
|
||||
w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf)
|
||||
w.logger.Infof("Responding with %d", status)
|
||||
if _, err := w.bufferedStr.Write(w.buf); err != nil {
|
||||
w.logger.Errorf("could not write headers frame: %s", err.Error())
|
||||
}
|
||||
if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil {
|
||||
w.logger.Errorf("could not write header frame payload: %s", err.Error())
|
||||
}
|
||||
if !w.headerWritten {
|
||||
w.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if !bodyAllowedForStatus(w.status) {
|
||||
return 0, http.ErrBodyNotAllowed
|
||||
}
|
||||
df := &dataFrame{Length: uint64(len(p))}
|
||||
w.buf = w.buf[:0]
|
||||
w.buf = df.Append(w.buf)
|
||||
if _, err := w.bufferedStr.Write(w.buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return w.bufferedStr.Write(p)
|
||||
}
|
||||
|
||||
func (w *responseWriter) Flush() {
|
||||
if err := w.bufferedStr.Flush(); err != nil {
|
||||
w.logger.Errorf("could not flush to stream: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) StreamCreator() StreamCreator {
|
||||
return w.conn
|
||||
}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == http.StatusNoContent:
|
||||
return false
|
||||
case status == http.StatusNotModified:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
247
vendor/github.com/quic-go/quic-go/http3/roundtrip.go
generated
vendored
Normal file
247
vendor/github.com/quic-go/quic-go/http3/roundtrip.go
generated
vendored
Normal file
|
@ -0,0 +1,247 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type roundTripCloser interface {
|
||||
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
||||
HandshakeComplete() bool
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
// DisableCompression, if true, prevents the Transport from
|
||||
// requesting compression with an "Accept-Encoding: gzip"
|
||||
// request header when the Request contains no existing
|
||||
// Accept-Encoding value. If the Transport requests gzip on
|
||||
// its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body. However, if the user
|
||||
// explicitly requested gzip it is not automatically
|
||||
// uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// QuicConfig is the quic.Config used for dialing new connections.
|
||||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Enable support for HTTP/3 datagrams.
|
||||
// If set to true, QuicConfig.EnableDatagram will be set.
|
||||
// See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html.
|
||||
EnableDatagrams bool
|
||||
|
||||
// Additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
|
||||
AdditionalSettings map[uint64]uint64
|
||||
|
||||
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
|
||||
// It is called right after parsing the frame type.
|
||||
// If parsing the frame type fails, the error is passed to the callback.
|
||||
// In that case, the frame type will not be set.
|
||||
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
||||
// (by returning hijacked false).
|
||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||
|
||||
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
|
||||
// If parsing the stream type fails, the error is passed to the callback.
|
||||
// In that case, the stream type will not be set.
|
||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, quic.DialAddrEarlyContext will be used.
|
||||
Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
|
||||
|
||||
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||
// allowed in the server's response header.
|
||||
// Zero means to use a default limit.
|
||||
MaxResponseHeaderBytes int64
|
||||
|
||||
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
type RoundTripOpt struct {
|
||||
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
|
||||
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
|
||||
OnlyCachedConn bool
|
||||
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
|
||||
// If set, context cancellations have no effect after the response headers are received.
|
||||
DontCloseRequestStream bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.RoundTripper = &RoundTripper{}
|
||||
_ io.Closer = &RoundTripper{}
|
||||
)
|
||||
|
||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||
|
||||
// RoundTripOpt is like RoundTrip, but takes options.
|
||||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if req.URL == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("http3: nil Request.URL")
|
||||
}
|
||||
if req.URL.Host == "" {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("http3: no Host in request URL")
|
||||
}
|
||||
if req.Header == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("http3: nil Request.Header")
|
||||
}
|
||||
if req.URL.Scheme != "https" {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
|
||||
}
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("http3: invalid http header field name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.Method != "" && !validMethod(req.Method) {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("http3: invalid method %q", req.Method)
|
||||
}
|
||||
|
||||
hostname := authorityAddr("https", hostnameFromRequest(req))
|
||||
cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rsp, err := cl.RoundTripOpt(req, opt)
|
||||
if err != nil {
|
||||
r.removeClient(hostname)
|
||||
if isReused {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
return r.RoundTripOpt(req, opt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rsp, err
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if r.clients == nil {
|
||||
r.clients = make(map[string]roundTripCloser)
|
||||
}
|
||||
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
if onlyCached {
|
||||
return nil, false, ErrNoCachedConn
|
||||
}
|
||||
var err error
|
||||
newCl := newClient
|
||||
if r.newClient != nil {
|
||||
newCl = r.newClient
|
||||
}
|
||||
client, err = newCl(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{
|
||||
EnableDatagram: r.EnableDatagrams,
|
||||
DisableCompression: r.DisableCompression,
|
||||
MaxHeaderBytes: r.MaxResponseHeaderBytes,
|
||||
StreamHijacker: r.StreamHijacker,
|
||||
UniStreamHijacker: r.UniStreamHijacker,
|
||||
},
|
||||
r.QuicConfig,
|
||||
r.Dial,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
r.clients[hostname] = client
|
||||
} else if client.HandshakeComplete() {
|
||||
isReused = true
|
||||
}
|
||||
return client, isReused, nil
|
||||
}
|
||||
|
||||
func (r *RoundTripper) removeClient(hostname string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
if r.clients == nil {
|
||||
return
|
||||
}
|
||||
delete(r.clients, hostname)
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this RoundTripper has used
|
||||
func (r *RoundTripper) Close() error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for _, client := range r.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.clients = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeRequestBody(req *http.Request) {
|
||||
if req.Body != nil {
|
||||
req.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
/*
|
||||
Method = "OPTIONS" ; Section 9.2
|
||||
| "GET" ; Section 9.3
|
||||
| "HEAD" ; Section 9.4
|
||||
| "POST" ; Section 9.5
|
||||
| "PUT" ; Section 9.6
|
||||
| "DELETE" ; Section 9.7
|
||||
| "TRACE" ; Section 9.8
|
||||
| "CONNECT" ; Section 9.9
|
||||
| extension-method
|
||||
extension-method = token
|
||||
token = 1*<any CHAR except CTLs or separators>
|
||||
*/
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
// copied from net/http/http.go
|
||||
func isNotToken(r rune) bool {
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
752
vendor/github.com/quic-go/quic-go/http3/server.go
generated
vendored
Normal file
752
vendor/github.com/quic-go/quic-go/http3/server.go
generated
vendored
Normal file
|
@ -0,0 +1,752 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
// allows mocking of quic.Listen and quic.ListenAddr
|
||||
var (
|
||||
quicListen = quic.ListenEarly
|
||||
quicListenAddr = quic.ListenAddrEarly
|
||||
)
|
||||
|
||||
const (
|
||||
// NextProtoH3Draft29 is the ALPN protocol negotiated during the TLS handshake, for QUIC draft 29.
|
||||
NextProtoH3Draft29 = "h3-29"
|
||||
// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2.
|
||||
NextProtoH3 = "h3"
|
||||
)
|
||||
|
||||
// StreamType is the stream type of a unidirectional stream.
|
||||
type StreamType uint64
|
||||
|
||||
const (
|
||||
streamTypeControlStream = 0
|
||||
streamTypePushStream = 1
|
||||
streamTypeQPACKEncoderStream = 2
|
||||
streamTypeQPACKDecoderStream = 3
|
||||
)
|
||||
|
||||
func versionToALPN(v protocol.VersionNumber) string {
|
||||
if v == protocol.Version1 || v == protocol.Version2 {
|
||||
return NextProtoH3
|
||||
}
|
||||
if v == protocol.VersionTLS || v == protocol.VersionDraft29 {
|
||||
return NextProtoH3Draft29
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConfigureTLSConfig creates a new tls.Config which can be used
|
||||
// to create a quic.Listener meant for serving http3. The created
|
||||
// tls.Config adds the functionality of detecting the used QUIC version
|
||||
// in order to set the correct ALPN value for the http3 connection.
|
||||
func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
|
||||
// The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set.
|
||||
// That way, we can get the QUIC version and set the correct ALPN value.
|
||||
return &tls.Config{
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
// determine the ALPN from the QUIC version used
|
||||
proto := NextProtoH3
|
||||
if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok {
|
||||
proto = versionToALPN(qconn.GetQUICVersion())
|
||||
}
|
||||
config := tlsConf
|
||||
if tlsConf.GetConfigForClient != nil {
|
||||
getConfigForClient := tlsConf.GetConfigForClient
|
||||
var err error
|
||||
conf, err := getConfigForClient(ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf != nil {
|
||||
config = conf
|
||||
}
|
||||
}
|
||||
if config == nil {
|
||||
return nil, nil
|
||||
}
|
||||
config = config.Clone()
|
||||
config.NextProtos = []string{proto}
|
||||
return config, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// contextKey is a value for use with context.WithValue. It's used as
|
||||
// a pointer so it fits in an interface{} without allocation.
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name }
|
||||
|
||||
// ServerContextKey is a context key. It can be used in HTTP
|
||||
// handlers with Context.Value to access the server that
|
||||
// started the handler. The associated value will be of
|
||||
// type *http3.Server.
|
||||
var ServerContextKey = &contextKey{"http3-server"}
|
||||
|
||||
type requestError struct {
|
||||
err error
|
||||
streamErr errorCode
|
||||
connErr errorCode
|
||||
}
|
||||
|
||||
func newStreamError(code errorCode, err error) requestError {
|
||||
return requestError{err: err, streamErr: code}
|
||||
}
|
||||
|
||||
func newConnError(code errorCode, err error) requestError {
|
||||
return requestError{err: err, connErr: code}
|
||||
}
|
||||
|
||||
// listenerInfo contains info about specific listener added with addListener
|
||||
type listenerInfo struct {
|
||||
port int // 0 means that no info about port is available
|
||||
}
|
||||
|
||||
// Server is a HTTP/3 server.
|
||||
type Server struct {
|
||||
// Addr optionally specifies the UDP address for the server to listen on,
|
||||
// in the form "host:port".
|
||||
//
|
||||
// When used by ListenAndServe and ListenAndServeTLS methods, if empty,
|
||||
// ":https" (port 443) is used. See net.Dial for details of the address
|
||||
// format.
|
||||
//
|
||||
// Otherwise, if Port is not set and underlying QUIC listeners do not
|
||||
// have valid port numbers, the port part is used in Alt-Svc headers set
|
||||
// with SetQuicHeaders.
|
||||
Addr string
|
||||
|
||||
// Port is used in Alt-Svc response headers set with SetQuicHeaders. If
|
||||
// needed Port can be manually set when the Server is created.
|
||||
//
|
||||
// This is useful when a Layer 4 firewall is redirecting UDP traffic and
|
||||
// clients must use a port different from the port the Server is
|
||||
// listening on.
|
||||
Port int
|
||||
|
||||
// TLSConfig provides a TLS configuration for use by server. It must be
|
||||
// set for ListenAndServe and Serve methods.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// QuicConfig provides the parameters for QUIC connection created with
|
||||
// Serve. If nil, it uses reasonable default values.
|
||||
//
|
||||
// Configured versions are also used in Alt-Svc response header set with
|
||||
// SetQuicHeaders.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Handler is the HTTP request handler to use. If not set, defaults to
|
||||
// http.NotFound.
|
||||
Handler http.Handler
|
||||
|
||||
// EnableDatagrams enables support for HTTP/3 datagrams.
|
||||
// If set to true, QuicConfig.EnableDatagram will be set.
|
||||
// See https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-07.
|
||||
EnableDatagrams bool
|
||||
|
||||
// MaxHeaderBytes controls the maximum number of bytes the server will
|
||||
// read parsing the request HEADERS frame. It does not limit the size of
|
||||
// the request body. If zero or negative, http.DefaultMaxHeaderBytes is
|
||||
// used.
|
||||
MaxHeaderBytes int
|
||||
|
||||
// AdditionalSettings specifies additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
|
||||
AdditionalSettings map[uint64]uint64
|
||||
|
||||
// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
|
||||
// It is called right after parsing the frame type.
|
||||
// If parsing the frame type fails, the error is passed to the callback.
|
||||
// In that case, the frame type will not be set.
|
||||
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
||||
// (by returning hijacked false).
|
||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||
|
||||
// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
|
||||
// If parsing the stream type fails, the error is passed to the callback.
|
||||
// In that case, the stream type will not be set.
|
||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||
|
||||
mutex sync.RWMutex
|
||||
listeners map[*quic.EarlyListener]listenerInfo
|
||||
|
||||
closed bool
|
||||
|
||||
altSvcHeader string
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||
//
|
||||
// If s.Addr is blank, ":https" is used.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
return s.serveConn(s.TLSConfig, nil)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||
//
|
||||
// If s.Addr is blank, ":https" is used.
|
||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
return s.serveConn(config, nil)
|
||||
}
|
||||
|
||||
// Serve an existing UDP connection.
|
||||
// It is possible to reuse the same connection for outgoing connections.
|
||||
// Closing the server does not close the connection.
|
||||
func (s *Server) Serve(conn net.PacketConn) error {
|
||||
return s.serveConn(s.TLSConfig, conn)
|
||||
}
|
||||
|
||||
// ServeQUICConn serves a single QUIC connection.
|
||||
func (s *Server) ServeQUICConn(conn quic.Connection) error {
|
||||
s.mutex.Lock()
|
||||
if s.logger == nil {
|
||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return s.handleConn(conn)
|
||||
}
|
||||
|
||||
// ServeListener serves an existing QUIC listener.
|
||||
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
|
||||
// and use it to construct a http3-friendly QUIC listener.
|
||||
// Closing the server does close the listener.
|
||||
func (s *Server) ServeListener(ln quic.EarlyListener) error {
|
||||
if err := s.addListener(&ln); err != nil {
|
||||
return err
|
||||
}
|
||||
err := s.serveListener(ln)
|
||||
s.removeListener(&ln)
|
||||
return err
|
||||
}
|
||||
|
||||
var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
|
||||
|
||||
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
|
||||
if tlsConf == nil {
|
||||
return errServerWithoutTLSConfig
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
closed := s.closed
|
||||
s.mutex.Unlock()
|
||||
if closed {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
|
||||
baseConf := ConfigureTLSConfig(tlsConf)
|
||||
quicConf := s.QuicConfig
|
||||
if quicConf == nil {
|
||||
quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }}
|
||||
} else {
|
||||
quicConf = s.QuicConfig.Clone()
|
||||
}
|
||||
if s.EnableDatagrams {
|
||||
quicConf.EnableDatagrams = true
|
||||
}
|
||||
|
||||
var ln quic.EarlyListener
|
||||
var err error
|
||||
if conn == nil {
|
||||
addr := s.Addr
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
}
|
||||
ln, err = quicListenAddr(addr, baseConf, quicConf)
|
||||
} else {
|
||||
ln, err = quicListen(conn, baseConf, quicConf)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.addListener(&ln); err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.serveListener(ln)
|
||||
s.removeListener(&ln)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) serveListener(ln quic.EarlyListener) error {
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
if err := s.handleConn(conn); err != nil {
|
||||
s.logger.Debugf(err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func extractPort(addr string) (int, error) {
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
portInt, err := net.LookupPort("tcp", portStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return portInt, nil
|
||||
}
|
||||
|
||||
func (s *Server) generateAltSvcHeader() {
|
||||
if len(s.listeners) == 0 {
|
||||
// Don't announce any ports since no one is listening for connections
|
||||
s.altSvcHeader = ""
|
||||
return
|
||||
}
|
||||
|
||||
// This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed.
|
||||
supportedVersions := protocol.SupportedVersions
|
||||
if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 {
|
||||
supportedVersions = s.QuicConfig.Versions
|
||||
}
|
||||
|
||||
// keep track of which have been seen so we don't yield duplicate values
|
||||
seen := make(map[string]struct{}, len(supportedVersions))
|
||||
var versionStrings []string
|
||||
for _, version := range supportedVersions {
|
||||
if v := versionToALPN(version); len(v) > 0 {
|
||||
if _, ok := seen[v]; !ok {
|
||||
versionStrings = append(versionStrings, v)
|
||||
seen[v] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var altSvc []string
|
||||
addPort := func(port int) {
|
||||
for _, v := range versionStrings {
|
||||
altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port))
|
||||
}
|
||||
}
|
||||
|
||||
if s.Port != 0 {
|
||||
// if Port is specified, we must use it instead of the
|
||||
// listener addresses since there's a reason it's specified.
|
||||
addPort(s.Port)
|
||||
} else {
|
||||
// if we have some listeners assigned, try to find ports
|
||||
// which we can announce, otherwise nothing should be announced
|
||||
validPortsFound := false
|
||||
for _, info := range s.listeners {
|
||||
if info.port != 0 {
|
||||
addPort(info.port)
|
||||
validPortsFound = true
|
||||
}
|
||||
}
|
||||
if !validPortsFound {
|
||||
if port, err := extractPort(s.Addr); err == nil {
|
||||
addPort(port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.altSvcHeader = strings.Join(altSvc, ",")
|
||||
}
|
||||
|
||||
// We store a pointer to interface in the map set. This is safe because we only
|
||||
// call trackListener via Serve and can track+defer untrack the same pointer to
|
||||
// local variable there. We never need to compare a Listener from another caller.
|
||||
func (s *Server) addListener(l *quic.EarlyListener) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
if s.logger == nil {
|
||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||
}
|
||||
if s.listeners == nil {
|
||||
s.listeners = make(map[*quic.EarlyListener]listenerInfo)
|
||||
}
|
||||
|
||||
if port, err := extractPort((*l).Addr().String()); err == nil {
|
||||
s.listeners[l] = listenerInfo{port}
|
||||
} else {
|
||||
s.logger.Errorf(
|
||||
"Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err)
|
||||
s.listeners[l] = listenerInfo{}
|
||||
}
|
||||
s.generateAltSvcHeader()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) removeListener(l *quic.EarlyListener) {
|
||||
s.mutex.Lock()
|
||||
delete(s.listeners, l)
|
||||
s.generateAltSvcHeader()
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(conn quic.Connection) error {
|
||||
decoder := qpack.NewDecoder(nil)
|
||||
|
||||
// send a SETTINGS frame
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening the control stream failed: %w", err)
|
||||
}
|
||||
b := make([]byte, 0, 64)
|
||||
b = quicvarint.Append(b, streamTypeControlStream) // stream type
|
||||
b = (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Append(b)
|
||||
str.Write(b)
|
||||
|
||||
go s.handleUnidirectionalStreams(conn)
|
||||
|
||||
// Process all requests immediately.
|
||||
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
|
||||
for {
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
var appErr *quic.ApplicationError
|
||||
if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(errorNoError) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("accepting stream failed: %w", err)
|
||||
}
|
||||
go func() {
|
||||
rerr := s.handleRequest(conn, str, decoder, func() {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "")
|
||||
})
|
||||
if rerr.err == errHijacked {
|
||||
return
|
||||
}
|
||||
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
|
||||
s.logger.Debugf("Handling request failed: %s", err)
|
||||
if rerr.streamErr != 0 {
|
||||
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
|
||||
}
|
||||
if rerr.connErr != 0 {
|
||||
var reason string
|
||||
if rerr.err != nil {
|
||||
reason = rerr.err.Error()
|
||||
}
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
}
|
||||
return
|
||||
}
|
||||
str.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleUnidirectionalStreams(conn quic.Connection) {
|
||||
for {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
s.logger.Debugf("accepting unidirectional stream failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func(str quic.ReceiveStream) {
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
|
||||
return
|
||||
}
|
||||
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||
return
|
||||
}
|
||||
// We're only interested in the control stream here.
|
||||
switch streamType {
|
||||
case streamTypeControlStream:
|
||||
case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
|
||||
// Our QPACK implementation doesn't use the dynamic table yet.
|
||||
// TODO: check that only one stream of each type is opened.
|
||||
return
|
||||
case streamTypePushStream: // only the server can push
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
|
||||
return
|
||||
default:
|
||||
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
|
||||
return
|
||||
}
|
||||
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||
return
|
||||
}
|
||||
f, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
|
||||
return
|
||||
}
|
||||
if !sf.Datagram {
|
||||
return
|
||||
}
|
||||
// If datagram support was enabled on our side as well as on the client side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
|
||||
}
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) maxHeaderBytes() uint64 {
|
||||
if s.MaxHeaderBytes <= 0 {
|
||||
return http.DefaultMaxHeaderBytes
|
||||
}
|
||||
return uint64(s.MaxHeaderBytes)
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
|
||||
var ufh unknownFrameHandlerFunc
|
||||
if s.StreamHijacker != nil {
|
||||
ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
|
||||
}
|
||||
frame, err := parseNextFrame(str, ufh)
|
||||
if err != nil {
|
||||
if err == errHijacked {
|
||||
return requestError{err: errHijacked}
|
||||
}
|
||||
return newStreamError(errorRequestIncomplete, err)
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
return newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
|
||||
}
|
||||
if hf.Length > s.maxHeaderBytes() {
|
||||
return newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||
return newStreamError(errorRequestIncomplete, err)
|
||||
}
|
||||
hfs, err := decoder.DecodeFull(headerBlock)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
return newConnError(errorGeneralProtocolError, err)
|
||||
}
|
||||
req, err := requestFromHeaders(hfs)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
return newStreamError(errorGeneralProtocolError, err)
|
||||
}
|
||||
|
||||
connState := conn.ConnectionState().TLS.ConnectionState
|
||||
req.TLS = &connState
|
||||
req.RemoteAddr = conn.RemoteAddr().String()
|
||||
body := newRequestBody(newStream(str, onFrameError))
|
||||
req.Body = body
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
|
||||
} else {
|
||||
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||
}
|
||||
|
||||
ctx := str.Context()
|
||||
ctx = context.WithValue(ctx, ServerContextKey, s)
|
||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
|
||||
req = req.WithContext(ctx)
|
||||
r := newResponseWriter(str, conn, s.logger)
|
||||
defer r.Flush()
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
|
||||
var panicked bool
|
||||
func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicked = true
|
||||
if p == http.ErrAbortHandler {
|
||||
return
|
||||
}
|
||||
// Copied from net/http/server.go
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
||||
}
|
||||
}()
|
||||
handler.ServeHTTP(r, req)
|
||||
}()
|
||||
|
||||
if body.wasStreamHijacked() {
|
||||
return requestError{err: errHijacked}
|
||||
}
|
||||
|
||||
if panicked {
|
||||
r.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
r.WriteHeader(http.StatusOK)
|
||||
}
|
||||
// If the EOF was read by the handler, CancelRead() is a no-op.
|
||||
str.CancelRead(quic.StreamErrorCode(errorNoError))
|
||||
return requestError{}
|
||||
}
|
||||
|
||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
||||
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) Close() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.closed = true
|
||||
|
||||
var err error
|
||||
for ln := range s.listeners {
|
||||
if cerr := (*ln).Close(); cerr != nil && err == nil {
|
||||
err = cerr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
|
||||
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) CloseGracefully(timeout time.Duration) error {
|
||||
// TODO: implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found
|
||||
// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port
|
||||
// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr.
|
||||
var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr")
|
||||
|
||||
// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3.
|
||||
// The values set by default advertise all of the ports the server is listening on, but can be
|
||||
// changed to a specific port by setting Server.Port before launching the serverr.
|
||||
// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used
|
||||
// to extract the port, if specified.
|
||||
// For example, a server launched using ListenAndServe on an address with port 443 would set:
|
||||
//
|
||||
// Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
||||
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
if s.altSvcHeader == "" {
|
||||
return ErrNoAltSvcPort
|
||||
}
|
||||
// use the map directly to avoid constant canonicalization
|
||||
// since the key is already canonicalized
|
||||
hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
||||
// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
|
||||
// used when handler is nil.
|
||||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
server := &Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
}
|
||||
return server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the given network address for both, TLS and QUIC
|
||||
// connections in parallel. It returns if one of the two returns an error.
|
||||
// http.DefaultServeMux is used when handler is nil.
|
||||
// The correct Alt-Svc headers for QUIC are set.
|
||||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
// Load certs
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
}
|
||||
|
||||
// Open the listeners
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
// Start the servers
|
||||
quicServer := &Server{
|
||||
TLSConfig: config,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
hErr := make(chan error)
|
||||
qErr := make(chan error)
|
||||
go func() {
|
||||
hErr <- http.ListenAndServeTLS(addr, certFile, keyFile, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
quicServer.SetQuicHeaders(w.Header())
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
}()
|
||||
go func() {
|
||||
qErr <- quicServer.Serve(udpConn)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-hErr:
|
||||
quicServer.Close()
|
||||
return err
|
||||
case err := <-qErr:
|
||||
// Cannot close the HTTP server or wait for requests to complete properly :/
|
||||
return err
|
||||
}
|
||||
}
|
363
vendor/github.com/quic-go/quic-go/interface.go
generated
vendored
Normal file
363
vendor/github.com/quic-go/quic-go/interface.go
generated
vendored
Normal file
|
@ -0,0 +1,363 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// The StreamID is the ID of a QUIC stream.
|
||||
type StreamID = protocol.StreamID
|
||||
|
||||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
const (
|
||||
// VersionDraft29 is IETF QUIC draft-29
|
||||
VersionDraft29 = protocol.VersionDraft29
|
||||
// Version1 is RFC 9000
|
||||
Version1 = protocol.Version1
|
||||
Version2 = protocol.Version2
|
||||
)
|
||||
|
||||
// A ClientToken is a token received by the client.
|
||||
// It can be used to skip address validation on future connection attempts.
|
||||
type ClientToken struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
type TokenStore interface {
|
||||
// Pop searches for a ClientToken associated with the given key.
|
||||
// Since tokens are not supposed to be reused, it must remove the token from the cache.
|
||||
// It returns nil when no token is found.
|
||||
Pop(key string) (token *ClientToken)
|
||||
|
||||
// Put adds a token to the cache with the given key. It might get called
|
||||
// multiple times in a connection.
|
||||
Put(key string, token *ClientToken)
|
||||
}
|
||||
|
||||
// Err0RTTRejected is the returned from:
|
||||
// * Open{Uni}Stream{Sync}
|
||||
// * Accept{Uni}Stream
|
||||
// * Stream.Read and Stream.Write
|
||||
// when the server rejects a 0-RTT connection attempt.
|
||||
var Err0RTTRejected = errors.New("0-RTT rejected")
|
||||
|
||||
// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
|
||||
// It is set on the Connection.Context() context,
|
||||
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
|
||||
var ConnectionTracingKey = connTracingCtxKey{}
|
||||
|
||||
type connTracingCtxKey struct{}
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
// In addition to the errors listed on the Connection,
|
||||
// calls to stream functions can return a StreamError if the stream is canceled.
|
||||
type Stream interface {
|
||||
ReceiveStream
|
||||
SendStream
|
||||
// SetDeadline sets the read and write deadlines associated
|
||||
// with the connection. It is equivalent to calling both
|
||||
// SetReadDeadline and SetWriteDeadline.
|
||||
SetDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A ReceiveStream is a unidirectional Receive Stream.
|
||||
type ReceiveStream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Read reads data from the stream.
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
io.Reader
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// It will ask the peer to stop transmitting stream data.
|
||||
// Read will unblock immediately, and future Read calls will fail.
|
||||
// When called multiple times or after reading the io.EOF it is a no-op.
|
||||
CancelRead(StreamErrorCode)
|
||||
// SetReadDeadline sets the deadline for future Read calls and
|
||||
// any currently-blocked Read call.
|
||||
// A zero value for t means Read will not time out.
|
||||
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A SendStream is a unidirectional Send Stream.
|
||||
type SendStream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Write writes data to the stream.
|
||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
io.Writer
|
||||
// Close closes the write-direction of the stream.
|
||||
// Future calls to Write are not permitted after calling Close.
|
||||
// It must not be called concurrently with Write.
|
||||
// It must not be called after calling CancelWrite.
|
||||
io.Closer
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||
// Write will unblock immediately, and future calls to Write will fail.
|
||||
// When called multiple times or after closing the stream it is a no-op.
|
||||
CancelWrite(StreamErrorCode)
|
||||
// The Context is canceled as soon as the write-side of the stream is closed.
|
||||
// This happens when Close() or CancelWrite() is called, or when the peer
|
||||
// cancels the read-side of their stream.
|
||||
Context() context.Context
|
||||
// SetWriteDeadline sets the deadline for future Write calls
|
||||
// and any currently-blocked Write call.
|
||||
// Even if write times out, it may return n > 0, indicating that
|
||||
// some data was successfully written.
|
||||
// A zero value for t means Write will not time out.
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A Connection is a QUIC connection between two peers.
|
||||
// Calls to the connection (and to streams) can return the following types of errors:
|
||||
// * ApplicationError: for errors triggered by the application running on top of QUIC
|
||||
// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
|
||||
// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error)
|
||||
// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error)
|
||||
// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error)
|
||||
// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers
|
||||
type Connection interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptStream(context.Context) (Stream, error)
|
||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||
// OpenStream opens a new bidirectional QUIC stream.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// When reaching the peer's stream limit, err.Temporary() will be true.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenStream() (Stream, error)
|
||||
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||
// It blocks until a new stream can be opened.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenStreamSync(context.Context) (Stream, error)
|
||||
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// When reaching the peer's stream limit, Temporary() will be true.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenUniStream() (SendStream, error)
|
||||
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
|
||||
// It blocks until a new stream can be opened.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenUniStreamSync(context.Context) (SendStream, error)
|
||||
// LocalAddr returns the local address.
|
||||
LocalAddr() net.Addr
|
||||
// RemoteAddr returns the address of the peer.
|
||||
RemoteAddr() net.Addr
|
||||
// CloseWithError closes the connection with an error.
|
||||
// The error string will be sent to the peer.
|
||||
CloseWithError(ApplicationErrorCode, string) error
|
||||
// Context returns a context that is cancelled when the connection is closed.
|
||||
Context() context.Context
|
||||
// ConnectionState returns basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
// SendMessage sends a message as a datagram, as specified in RFC 9221.
|
||||
SendMessage([]byte) error
|
||||
// ReceiveMessage gets a message received in a datagram, as specified in RFC 9221.
|
||||
ReceiveMessage() ([]byte, error)
|
||||
}
|
||||
|
||||
// An EarlyConnection is a connection that is handshaking.
|
||||
// Data sent during the handshake is encrypted using the forward secure keys.
|
||||
// When using client certificates, the client's identity is only verified
|
||||
// after completion of the handshake.
|
||||
type EarlyConnection interface {
|
||||
Connection
|
||||
|
||||
// HandshakeComplete blocks until the handshake completes (or fails).
|
||||
// Data sent before completion of the handshake is encrypted with 1-RTT keys.
|
||||
// Note that the client's identity hasn't been verified yet.
|
||||
HandshakeComplete() context.Context
|
||||
|
||||
NextConnection() Connection
|
||||
}
|
||||
|
||||
// StatelessResetKey is a key used to derive stateless reset tokens.
|
||||
type StatelessResetKey [32]byte
|
||||
|
||||
// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000.
|
||||
// It is not able to handle QUIC Connection IDs longer than 20 bytes,
|
||||
// as they are allowed by RFC 8999.
|
||||
type ConnectionID = protocol.ConnectionID
|
||||
|
||||
// ConnectionIDFromBytes interprets b as a Connection ID. It panics if b is
|
||||
// longer than 20 bytes.
|
||||
func ConnectionIDFromBytes(b []byte) ConnectionID {
|
||||
return protocol.ParseConnectionID(b)
|
||||
}
|
||||
|
||||
// A ConnectionIDGenerator is an interface that allows clients to implement their own format
|
||||
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
|
||||
//
|
||||
// Connection IDs generated by an implementation should always produce IDs of constant size.
|
||||
type ConnectionIDGenerator interface {
|
||||
// GenerateConnectionID generates a new ConnectionID.
|
||||
// Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs.
|
||||
GenerateConnectionID() (ConnectionID, error)
|
||||
|
||||
// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
|
||||
// this interface.
|
||||
// Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size
|
||||
// connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID.
|
||||
// 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections
|
||||
// in the presence of a connection migration environment.
|
||||
ConnectionIDLen() int
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
Versions []VersionNumber
|
||||
// The length of the connection ID in bytes.
|
||||
// It can be 0, or any value between 4 and 18.
|
||||
// If not set, the interpretation depends on where the Config is used:
|
||||
// If used for dialing an address, a 0 byte connection ID will be used.
|
||||
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
|
||||
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
|
||||
ConnectionIDLength int
|
||||
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
|
||||
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
|
||||
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
|
||||
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
|
||||
ConnectionIDGenerator ConnectionIDGenerator
|
||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||
// If this value is zero, the timeout is set to 5 seconds.
|
||||
HandshakeIdleTimeout time.Duration
|
||||
// MaxIdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||
// The actual value for the idle timeout is the minimum of this value and the peer's.
|
||||
// This value only applies after the handshake has completed.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
MaxIdleTimeout time.Duration
|
||||
// RequireAddressValidation determines if a QUIC Retry packet is sent.
|
||||
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
|
||||
// If not set, every client is forced to prove its remote address.
|
||||
RequireAddressValidation func(net.Addr) bool
|
||||
// MaxRetryTokenAge is the maximum age of a Retry token.
|
||||
// If not set, it defaults to 5 seconds. Only valid for a server.
|
||||
MaxRetryTokenAge time.Duration
|
||||
// MaxTokenAge is the maximum age of the token presented during the handshake,
|
||||
// for tokens that were issued on a previous connection.
|
||||
// If not set, it defaults to 24 hours. Only valid for a server.
|
||||
MaxTokenAge time.Duration
|
||||
// The TokenStore stores tokens received from the server.
|
||||
// Tokens are used to skip address validation on future connection attempts.
|
||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||
// otherwise the token is associated with the server's IP address.
|
||||
TokenStore TokenStore
|
||||
// InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data.
|
||||
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
|
||||
// will increase the window up to MaxStreamReceiveWindow.
|
||||
// If this value is zero, it will default to 512 KB.
|
||||
InitialStreamReceiveWindow uint64
|
||||
// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 6 MB.
|
||||
MaxStreamReceiveWindow uint64
|
||||
// InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data.
|
||||
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
|
||||
// will increase the window up to MaxConnectionReceiveWindow.
|
||||
// If this value is zero, it will default to 512 KB.
|
||||
InitialConnectionReceiveWindow uint64
|
||||
// MaxConnectionReceiveWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 15 MB.
|
||||
MaxConnectionReceiveWindow uint64
|
||||
// AllowConnectionWindowIncrease is called every time the connection flow controller attempts
|
||||
// to increase the connection flow control window.
|
||||
// If set, the caller can prevent an increase of the window. Typically, it would do so to
|
||||
// limit the memory usage.
|
||||
// To avoid deadlocks, it is not valid to call other functions on the connection or on streams
|
||||
// in this callback.
|
||||
AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool
|
||||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||
// Values above 2^60 are invalid.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||
MaxIncomingStreams int64
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// Values above 2^60 are invalid.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
MaxIncomingUniStreams int64
|
||||
// The StatelessResetKey is used to generate stateless reset tokens.
|
||||
// If no key is configured, sending of stateless resets is disabled.
|
||||
StatelessResetKey *StatelessResetKey
|
||||
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
|
||||
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
|
||||
// every half of MaxIdleTimeout, whichever is smaller).
|
||||
KeepAlivePeriod time.Duration
|
||||
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
|
||||
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
|
||||
// Note that if Path MTU discovery is causing issues on your system, please open a new issue
|
||||
DisablePathMTUDiscovery bool
|
||||
// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
|
||||
// This can be useful if version information is exchanged out-of-band.
|
||||
// It has no effect for a client.
|
||||
DisableVersionNegotiationPackets bool
|
||||
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
|
||||
// When set, 0-RTT is enabled. When not set, 0-RTT is disabled.
|
||||
// Only valid for the server.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
Allow0RTT func(net.Addr) bool
|
||||
// Enable QUIC datagram support (RFC 9221).
|
||||
EnableDatagrams bool
|
||||
Tracer logging.Tracer
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about a QUIC connection
|
||||
type ConnectionState struct {
|
||||
TLS handshake.ConnectionState
|
||||
SupportsDatagrams bool
|
||||
Version VersionNumber
|
||||
}
|
||||
|
||||
// A Listener for incoming QUIC connections
|
||||
type Listener interface {
|
||||
// Close the server. All active connections will be closed.
|
||||
Close() error
|
||||
// Addr returns the local network addr that the server is listening on.
|
||||
Addr() net.Addr
|
||||
// Accept returns new connections. It should be called in a loop.
|
||||
Accept(context.Context) (Connection, error)
|
||||
}
|
||||
|
||||
// An EarlyListener listens for incoming QUIC connections,
|
||||
// and returns them before the handshake completes.
|
||||
type EarlyListener interface {
|
||||
// Close the server. All active connections will be closed.
|
||||
Close() error
|
||||
// Addr returns the local network addr that the server is listening on.
|
||||
Addr() net.Addr
|
||||
// Accept returns new early connections. It should be called in a loop.
|
||||
Accept(context.Context) (EarlyConnection, error)
|
||||
}
|
20
vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go
generated
vendored
Normal file
20
vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go
generated
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
package ackhandler
|
||||
|
||||
import "github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
// IsFrameAckEliciting returns true if the frame is ack-eliciting.
|
||||
func IsFrameAckEliciting(f wire.Frame) bool {
|
||||
_, isAck := f.(*wire.AckFrame)
|
||||
_, isConnectionClose := f.(*wire.ConnectionCloseFrame)
|
||||
return !isAck && !isConnectionClose
|
||||
}
|
||||
|
||||
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
|
||||
func HasAckElicitingFrames(fs []*Frame) bool {
|
||||
for _, f := range fs {
|
||||
if IsFrameAckEliciting(f.Frame) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
23
vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go
generated
vendored
Normal file
23
vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go
generated
vendored
Normal file
|
@ -0,0 +1,23 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler.
|
||||
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
|
||||
// clientAddressValidated has no effect for a client.
|
||||
func NewAckHandler(
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
initialMaxDatagramSize protocol.ByteCount,
|
||||
rttStats *utils.RTTStats,
|
||||
clientAddressValidated bool,
|
||||
pers protocol.Perspective,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
) (SentPacketHandler, ReceivedPacketHandler) {
|
||||
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger)
|
||||
return sph, newReceivedPacketHandler(sph, rttStats, logger)
|
||||
}
|
29
vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go
generated
vendored
Normal file
29
vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go
generated
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type Frame struct {
|
||||
wire.Frame // nil if the frame has already been acknowledged in another packet
|
||||
OnLost func(wire.Frame)
|
||||
OnAcked func(wire.Frame)
|
||||
}
|
||||
|
||||
var framePool = sync.Pool{New: func() any { return &Frame{} }}
|
||||
|
||||
func GetFrame() *Frame {
|
||||
f := framePool.Get().(*Frame)
|
||||
f.OnLost = nil
|
||||
f.OnAcked = nil
|
||||
return f
|
||||
}
|
||||
|
||||
func putFrame(f *Frame) {
|
||||
f.Frame = nil
|
||||
f.OnLost = nil
|
||||
f.OnAcked = nil
|
||||
framePool.Put(f)
|
||||
}
|
52
vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
52
vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
|
@ -0,0 +1,52 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(packet *Packet)
|
||||
ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error)
|
||||
ReceivedBytes(protocol.ByteCount)
|
||||
DropPackets(protocol.EncryptionLevel)
|
||||
ResetForRetry() error
|
||||
SetHandshakeConfirmed()
|
||||
|
||||
// The SendMode determines if and what kind of packets can be sent.
|
||||
SendMode() SendMode
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() time.Time
|
||||
// HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment.
|
||||
HasPacingBudget() bool
|
||||
SetMaxDatagramSize(count protocol.ByteCount)
|
||||
|
||||
// only to be called once the handshake is complete
|
||||
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
|
||||
|
||||
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
|
||||
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
|
||||
|
||||
GetLossDetectionTimeout() time.Time
|
||||
OnLossDetectionTimeout() error
|
||||
}
|
||||
|
||||
type sentPacketTracker interface {
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
ReceivedPacket(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool
|
||||
ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error
|
||||
DropPackets(protocol.EncryptionLevel)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
|
||||
}
|
3
vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go
generated
vendored
Normal file
3
vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go
generated
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
package ackhandler
|
||||
|
||||
//go:generate sh -c "../../mockgen_private.sh ackhandler mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler sentPacketTracker"
|
55
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
55
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
|
@ -0,0 +1,55 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// A Packet is a packet
|
||||
type Packet struct {
|
||||
PacketNumber protocol.PacketNumber
|
||||
Frames []*Frame
|
||||
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
SendTime time.Time
|
||||
|
||||
IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller.
|
||||
|
||||
includedInBytesInFlight bool
|
||||
declaredLost bool
|
||||
skippedPacket bool
|
||||
}
|
||||
|
||||
func (p *Packet) outstanding() bool {
|
||||
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket
|
||||
}
|
||||
|
||||
var packetPool = sync.Pool{New: func() any { return &Packet{} }}
|
||||
|
||||
func GetPacket() *Packet {
|
||||
p := packetPool.Get().(*Packet)
|
||||
p.PacketNumber = 0
|
||||
p.Frames = nil
|
||||
p.LargestAcked = 0
|
||||
p.Length = 0
|
||||
p.EncryptionLevel = protocol.EncryptionLevel(0)
|
||||
p.SendTime = time.Time{}
|
||||
p.IsPathMTUProbePacket = false
|
||||
p.includedInBytesInFlight = false
|
||||
p.declaredLost = false
|
||||
p.skippedPacket = false
|
||||
return p
|
||||
}
|
||||
|
||||
// We currently only return Packets back into the pool when they're acknowledged (not when they're lost).
|
||||
// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool.
|
||||
func putPacket(p *Packet) {
|
||||
for _, f := range p.Frames {
|
||||
putFrame(f)
|
||||
}
|
||||
p.Frames = nil
|
||||
packetPool.Put(p)
|
||||
}
|
76
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
Normal file
76
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
Normal file
|
@ -0,0 +1,76 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type packetNumberGenerator interface {
|
||||
Peek() protocol.PacketNumber
|
||||
Pop() protocol.PacketNumber
|
||||
}
|
||||
|
||||
type sequentialPacketNumberGenerator struct {
|
||||
next protocol.PacketNumber
|
||||
}
|
||||
|
||||
var _ packetNumberGenerator = &sequentialPacketNumberGenerator{}
|
||||
|
||||
func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator {
|
||||
return &sequentialPacketNumberGenerator{next: initial}
|
||||
}
|
||||
|
||||
func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
|
||||
return p.next
|
||||
}
|
||||
|
||||
func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber {
|
||||
next := p.next
|
||||
p.next++
|
||||
return next
|
||||
}
|
||||
|
||||
// The skippingPacketNumberGenerator generates the packet number for the next packet
|
||||
// it randomly skips a packet number every averagePeriod packets (on average).
|
||||
// It is guaranteed to never skip two consecutive packet numbers.
|
||||
type skippingPacketNumberGenerator struct {
|
||||
period protocol.PacketNumber
|
||||
maxPeriod protocol.PacketNumber
|
||||
|
||||
next protocol.PacketNumber
|
||||
nextToSkip protocol.PacketNumber
|
||||
|
||||
rng utils.Rand
|
||||
}
|
||||
|
||||
var _ packetNumberGenerator = &skippingPacketNumberGenerator{}
|
||||
|
||||
func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator {
|
||||
g := &skippingPacketNumberGenerator{
|
||||
next: initial,
|
||||
period: initialPeriod,
|
||||
maxPeriod: maxPeriod,
|
||||
}
|
||||
g.generateNewSkip()
|
||||
return g
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
|
||||
return p.next
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber {
|
||||
next := p.next
|
||||
p.next++ // generate a new packet number for the next packet
|
||||
if p.next == p.nextToSkip {
|
||||
p.next++
|
||||
p.generateNewSkip()
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) generateNewSkip() {
|
||||
// make sure that there are never two consecutive packet numbers that are skipped
|
||||
p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
|
||||
p.period = utils.Min(2*p.period, p.maxPeriod)
|
||||
}
|
137
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
137
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
|
@ -0,0 +1,137 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type receivedPacketHandler struct {
|
||||
sentPackets sentPacketTracker
|
||||
|
||||
initialPackets *receivedPacketTracker
|
||||
handshakePackets *receivedPacketTracker
|
||||
appDataPackets *receivedPacketTracker
|
||||
|
||||
lowest1RTTPacket protocol.PacketNumber
|
||||
}
|
||||
|
||||
var _ ReceivedPacketHandler = &receivedPacketHandler{}
|
||||
|
||||
func newReceivedPacketHandler(
|
||||
sentPackets sentPacketTracker,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
sentPackets: sentPackets,
|
||||
initialPackets: newReceivedPacketTracker(rttStats, logger),
|
||||
handshakePackets: newReceivedPacketTracker(rttStats, logger),
|
||||
appDataPackets: newReceivedPacketTracker(rttStats, logger),
|
||||
lowest1RTTPacket: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(
|
||||
pn protocol.PacketNumber,
|
||||
ecn protocol.ECN,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
rcvTime time.Time,
|
||||
shouldInstigateAck bool,
|
||||
) error {
|
||||
h.sentPackets.ReceivedPacket(encLevel)
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
|
||||
case protocol.EncryptionHandshake:
|
||||
return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
|
||||
case protocol.Encryption0RTT:
|
||||
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
|
||||
return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
|
||||
}
|
||||
return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
|
||||
case protocol.Encryption1RTT:
|
||||
if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
|
||||
h.lowest1RTTPacket = pn
|
||||
}
|
||||
if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck); err != nil {
|
||||
return err
|
||||
}
|
||||
h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked())
|
||||
return nil
|
||||
default:
|
||||
panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
||||
//nolint:exhaustive // 1-RTT packet number space is never dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.initialPackets = nil
|
||||
case protocol.EncryptionHandshake:
|
||||
h.handshakePackets = nil
|
||||
case protocol.Encryption0RTT:
|
||||
// Nothing to do here.
|
||||
// If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted.
|
||||
default:
|
||||
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
|
||||
var initialAlarm, handshakeAlarm time.Time
|
||||
if h.initialPackets != nil {
|
||||
initialAlarm = h.initialPackets.GetAlarmTimeout()
|
||||
}
|
||||
if h.handshakePackets != nil {
|
||||
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
|
||||
}
|
||||
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
|
||||
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
|
||||
var ack *wire.AckFrame
|
||||
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialPackets != nil {
|
||||
ack = h.initialPackets.GetAckFrame(onlyIfQueued)
|
||||
}
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakePackets != nil {
|
||||
ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
|
||||
}
|
||||
case protocol.Encryption1RTT:
|
||||
// 0-RTT packets can't contain ACK frames
|
||||
return h.appDataPackets.GetAckFrame(onlyIfQueued)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
|
||||
// Set it to 0 in order to save bytes.
|
||||
if ack != nil {
|
||||
ack.DelayTime = 0
|
||||
}
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialPackets != nil {
|
||||
return h.initialPackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakePackets != nil {
|
||||
return h.handshakePackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
case protocol.Encryption0RTT, protocol.Encryption1RTT:
|
||||
return h.appDataPackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
panic("unexpected encryption level")
|
||||
}
|
151
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go
generated
vendored
Normal file
151
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go
generated
vendored
Normal file
|
@ -0,0 +1,151 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// interval is an interval from one PacketNumber to the other
|
||||
type interval struct {
|
||||
Start protocol.PacketNumber
|
||||
End protocol.PacketNumber
|
||||
}
|
||||
|
||||
var intervalElementPool sync.Pool
|
||||
|
||||
func init() {
|
||||
intervalElementPool = *list.NewPool[interval]()
|
||||
}
|
||||
|
||||
// The receivedPacketHistory stores if a packet number has already been received.
|
||||
// It generates ACK ranges which can be used to assemble an ACK frame.
|
||||
// It does not store packet contents.
|
||||
type receivedPacketHistory struct {
|
||||
ranges *list.List[interval]
|
||||
|
||||
deletedBelow protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newReceivedPacketHistory() *receivedPacketHistory {
|
||||
return &receivedPacketHistory{
|
||||
ranges: list.NewWithPool[interval](&intervalElementPool),
|
||||
}
|
||||
}
|
||||
|
||||
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
|
||||
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
|
||||
// ignore delayed packets, if we already deleted the range
|
||||
if p < h.deletedBelow {
|
||||
return false
|
||||
}
|
||||
isNew := h.addToRanges(p)
|
||||
h.maybeDeleteOldRanges()
|
||||
return isNew
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
|
||||
if h.ranges.Len() == 0 {
|
||||
h.ranges.PushBack(interval{Start: p, End: p})
|
||||
return true
|
||||
}
|
||||
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
// p already included in an existing range. Nothing to do here
|
||||
if p >= el.Value.Start && p <= el.Value.End {
|
||||
return false
|
||||
}
|
||||
|
||||
if el.Value.End == p-1 { // extend a range at the end
|
||||
el.Value.End = p
|
||||
return true
|
||||
}
|
||||
if el.Value.Start == p+1 { // extend a range at the beginning
|
||||
el.Value.Start = p
|
||||
|
||||
prev := el.Prev()
|
||||
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
|
||||
prev.Value.End = el.Value.End
|
||||
h.ranges.Remove(el)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// create a new range at the end
|
||||
if p > el.Value.End {
|
||||
h.ranges.InsertAfter(interval{Start: p, End: p}, el)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// create a new range at the beginning
|
||||
h.ranges.InsertBefore(interval{Start: p, End: p}, h.ranges.Front())
|
||||
return true
|
||||
}
|
||||
|
||||
// Delete old ranges, if we're tracking more than 500 of them.
|
||||
// This is a DoS defense against a peer that sends us too many gaps.
|
||||
func (h *receivedPacketHistory) maybeDeleteOldRanges() {
|
||||
for h.ranges.Len() > protocol.MaxNumAckRanges {
|
||||
h.ranges.Remove(h.ranges.Front())
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteBelow deletes all entries below (but not including) p
|
||||
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||
if p < h.deletedBelow {
|
||||
return
|
||||
}
|
||||
h.deletedBelow = p
|
||||
|
||||
nextEl := h.ranges.Front()
|
||||
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
||||
nextEl = el.Next()
|
||||
|
||||
if el.Value.End < p { // delete a whole range
|
||||
h.ranges.Remove(el)
|
||||
} else if p > el.Value.Start && p <= el.Value.End {
|
||||
el.Value.Start = p
|
||||
return
|
||||
} else { // no ranges affected. Nothing to do
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame
|
||||
func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange {
|
||||
if h.ranges.Len() > 0 {
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
ackRanges = append(ackRanges, wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End})
|
||||
}
|
||||
}
|
||||
return ackRanges
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
||||
ackRange := wire.AckRange{}
|
||||
if h.ranges.Len() > 0 {
|
||||
r := h.ranges.Back().Value
|
||||
ackRange.Smallest = r.Start
|
||||
ackRange.Largest = r.End
|
||||
}
|
||||
return ackRange
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool {
|
||||
if p < h.deletedBelow {
|
||||
return true
|
||||
}
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
if p > el.Value.End {
|
||||
return false
|
||||
}
|
||||
if p <= el.Value.End && p >= el.Value.Start {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
194
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
Normal file
194
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
Normal file
|
@ -0,0 +1,194 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// number of ack-eliciting packets received before sending an ack.
|
||||
const packetsBeforeAck = 2
|
||||
|
||||
type receivedPacketTracker struct {
|
||||
largestObserved protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
largestObservedReceivedTime time.Time
|
||||
ect0, ect1, ecnce uint64
|
||||
|
||||
packetHistory *receivedPacketHistory
|
||||
|
||||
maxAckDelay time.Duration
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
hasNewAck bool // true as soon as we received an ack-eliciting new packet
|
||||
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
|
||||
|
||||
ackElicitingPacketsReceivedSinceLastAck int
|
||||
ackAlarm time.Time
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newReceivedPacketTracker(
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) *receivedPacketTracker {
|
||||
return &receivedPacketTracker{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
maxAckDelay: protocol.MaxAckDelay,
|
||||
rttStats: rttStats,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) error {
|
||||
if isNew := h.packetHistory.ReceivedPacket(packetNumber); !isNew {
|
||||
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", packetNumber)
|
||||
}
|
||||
|
||||
isMissing := h.isMissing(packetNumber)
|
||||
if packetNumber >= h.largestObserved {
|
||||
h.largestObserved = packetNumber
|
||||
h.largestObservedReceivedTime = rcvTime
|
||||
}
|
||||
|
||||
if shouldInstigateAck {
|
||||
h.hasNewAck = true
|
||||
}
|
||||
if shouldInstigateAck {
|
||||
h.maybeQueueAck(packetNumber, rcvTime, isMissing)
|
||||
}
|
||||
switch ecn {
|
||||
case protocol.ECNNon:
|
||||
case protocol.ECT0:
|
||||
h.ect0++
|
||||
case protocol.ECT1:
|
||||
h.ect1++
|
||||
case protocol.ECNCE:
|
||||
h.ecnce++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IgnoreBelow sets a lower limit for acknowledging packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) {
|
||||
if p <= h.ignoreBelow {
|
||||
return
|
||||
}
|
||||
h.ignoreBelow = p
|
||||
h.packetHistory.DeleteBelow(p)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tIgnoring all packets below %d.", p)
|
||||
}
|
||||
}
|
||||
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil || p < h.ignoreBelow {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
highestRange := h.packetHistory.GetHighestAckRange()
|
||||
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
|
||||
}
|
||||
|
||||
// maybeQueueAck queues an ACK, if necessary.
|
||||
func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) {
|
||||
// always acknowledge the first packet
|
||||
if h.lastAck == nil {
|
||||
if !h.ackQueued {
|
||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||
}
|
||||
h.ackQueued = true
|
||||
return
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
return
|
||||
}
|
||||
|
||||
h.ackElicitingPacketsReceivedSinceLastAck++
|
||||
|
||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||
if wasMissing {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
|
||||
}
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
// send an ACK every 2 ack-eliciting packets
|
||||
if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
|
||||
}
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
|
||||
}
|
||||
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
|
||||
}
|
||||
|
||||
// Queue an ACK if there are new missing packets to report.
|
||||
if h.hasNewMissingPackets() {
|
||||
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
// cancel the ack alarm
|
||||
h.ackAlarm = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
|
||||
if !h.hasNewAck {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
if onlyIfQueued {
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
||||
return nil
|
||||
}
|
||||
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||
}
|
||||
}
|
||||
|
||||
ack := wire.GetAckFrame()
|
||||
ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime))
|
||||
ack.ECT0 = h.ect0
|
||||
ack.ECT1 = h.ect1
|
||||
ack.ECNCE = h.ecnce
|
||||
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
|
||||
|
||||
if h.lastAck != nil {
|
||||
wire.PutAckFrame(h.lastAck)
|
||||
}
|
||||
h.lastAck = ack
|
||||
h.ackAlarm = time.Time{}
|
||||
h.ackQueued = false
|
||||
h.hasNewAck = false
|
||||
h.ackElicitingPacketsReceivedSinceLastAck = 0
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
||||
|
||||
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
|
||||
return h.packetHistory.IsPotentiallyDuplicate(pn)
|
||||
}
|
40
vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
40
vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
|
@ -0,0 +1,40 @@
|
|||
package ackhandler
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The SendMode says what kind of packets can be sent.
|
||||
type SendMode uint8
|
||||
|
||||
const (
|
||||
// SendNone means that no packets should be sent
|
||||
SendNone SendMode = iota
|
||||
// SendAck means an ACK-only packet should be sent
|
||||
SendAck
|
||||
// SendPTOInitial means that an Initial probe packet should be sent
|
||||
SendPTOInitial
|
||||
// SendPTOHandshake means that a Handshake probe packet should be sent
|
||||
SendPTOHandshake
|
||||
// SendPTOAppData means that an Application data probe packet should be sent
|
||||
SendPTOAppData
|
||||
// SendAny means that any packet should be sent
|
||||
SendAny
|
||||
)
|
||||
|
||||
func (s SendMode) String() string {
|
||||
switch s {
|
||||
case SendNone:
|
||||
return "none"
|
||||
case SendAck:
|
||||
return "ack"
|
||||
case SendPTOInitial:
|
||||
return "pto (Initial)"
|
||||
case SendPTOHandshake:
|
||||
return "pto (Handshake)"
|
||||
case SendPTOAppData:
|
||||
return "pto (Application Data)"
|
||||
case SendAny:
|
||||
return "any"
|
||||
default:
|
||||
return fmt.Sprintf("invalid send mode: %d", s)
|
||||
}
|
||||
}
|
861
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
861
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
|
@ -0,0 +1,861 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/congestion"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||
// Specified as an RTT multiplier.
|
||||
timeThreshold = 9.0 / 8
|
||||
// Maximum reordering in packets before packet threshold loss detection considers a packet lost.
|
||||
packetThreshold = 3
|
||||
// Before validating the client's address, the server won't send more than 3x bytes than it received.
|
||||
amplificationFactor = 3
|
||||
// We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet.
|
||||
minRTTAfterRetry = 5 * time.Millisecond
|
||||
// The PTO duration uses exponential backoff, but is truncated to a maximum value, as allowed by RFC 8961, section 4.4.
|
||||
maxPTODuration = 60 * time.Second
|
||||
)
|
||||
|
||||
type packetNumberSpace struct {
|
||||
history *sentPacketHistory
|
||||
pns packetNumberGenerator
|
||||
|
||||
lossTime time.Time
|
||||
lastAckElicitingPacketTime time.Time
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestSent protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace {
|
||||
var pns packetNumberGenerator
|
||||
if skipPNs {
|
||||
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
|
||||
} else {
|
||||
pns = newSequentialPacketNumberGenerator(initialPN)
|
||||
}
|
||||
return &packetNumberSpace{
|
||||
history: newSentPacketHistory(rttStats),
|
||||
pns: pns,
|
||||
largestSent: protocol.InvalidPacketNumber,
|
||||
largestAcked: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
type sentPacketHandler struct {
|
||||
initialPackets *packetNumberSpace
|
||||
handshakePackets *packetNumberSpace
|
||||
appDataPackets *packetNumberSpace
|
||||
|
||||
// Do we know that the peer completed address validation yet?
|
||||
// Always true for the server.
|
||||
peerCompletedAddressValidation bool
|
||||
bytesReceived protocol.ByteCount
|
||||
bytesSent protocol.ByteCount
|
||||
// Have we validated the peer's address yet?
|
||||
// Always true for the client.
|
||||
peerAddressValidated bool
|
||||
|
||||
handshakeConfirmed bool
|
||||
|
||||
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||
// example: we send an ACK for packets 90-100 with packet number 20
|
||||
// once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101
|
||||
// Only applies to the application-data packet number space.
|
||||
lowestNotConfirmedAcked protocol.PacketNumber
|
||||
|
||||
ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets
|
||||
|
||||
bytesInFlight protocol.ByteCount
|
||||
|
||||
congestion congestion.SendAlgorithmWithDebugInfos
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
// The number of times a PTO has been sent without receiving an ack.
|
||||
ptoCount uint32
|
||||
ptoMode SendMode
|
||||
// The number of PTO probe packets that should be sent.
|
||||
// Only applies to the application-data packet number space.
|
||||
numProbesToSend int
|
||||
|
||||
// The alarm timeout
|
||||
alarm time.Time
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
tracer logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
_ SentPacketHandler = &sentPacketHandler{}
|
||||
_ sentPacketTracker = &sentPacketHandler{}
|
||||
)
|
||||
|
||||
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
|
||||
// If the address was validated, the amplification limit doesn't apply. It has no effect for a client.
|
||||
func newSentPacketHandler(
|
||||
initialPN protocol.PacketNumber,
|
||||
initialMaxDatagramSize protocol.ByteCount,
|
||||
rttStats *utils.RTTStats,
|
||||
clientAddressValidated bool,
|
||||
pers protocol.Perspective,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
) *sentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
initialMaxDatagramSize,
|
||||
true, // use Reno
|
||||
tracer,
|
||||
)
|
||||
|
||||
return &sentPacketHandler{
|
||||
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
|
||||
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
|
||||
initialPackets: newPacketNumberSpace(initialPN, false, rttStats),
|
||||
handshakePackets: newPacketNumberSpace(0, false, rttStats),
|
||||
appDataPackets: newPacketNumberSpace(0, true, rttStats),
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
perspective: pers,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
||||
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial {
|
||||
// This function is called when the crypto setup seals a Handshake packet.
|
||||
// If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space
|
||||
// before SentPacket() was called for that Initial packet.
|
||||
return
|
||||
}
|
||||
h.dropPackets(encLevel)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) {
|
||||
if p.includedInBytesInFlight {
|
||||
if p.Length > h.bytesInFlight {
|
||||
panic("negative bytes_in_flight")
|
||||
}
|
||||
h.bytesInFlight -= p.Length
|
||||
p.includedInBytesInFlight = false
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
|
||||
// The server won't await address validation after the handshake is confirmed.
|
||||
// This applies even if we didn't receive an ACK for a Handshake packet.
|
||||
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
|
||||
h.peerCompletedAddressValidation = true
|
||||
}
|
||||
// remove outstanding packets from bytes_in_flight
|
||||
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
|
||||
h.removeFromBytesInFlight(p)
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
// drop the packet history
|
||||
//nolint:exhaustive // Not every packet number space can be dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.initialPackets = nil
|
||||
case protocol.EncryptionHandshake:
|
||||
h.handshakePackets = nil
|
||||
case protocol.Encryption0RTT:
|
||||
// This function is only called when 0-RTT is rejected,
|
||||
// and not when the client drops 0-RTT keys when the handshake completes.
|
||||
// When 0-RTT is rejected, all application data sent so far becomes invalid.
|
||||
// Delete the packets from the history and remove them from bytes_in_flight.
|
||||
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.EncryptionLevel != protocol.Encryption0RTT {
|
||||
return false, nil
|
||||
}
|
||||
h.removeFromBytesInFlight(p)
|
||||
h.appDataPackets.history.Remove(p.PacketNumber)
|
||||
return true, nil
|
||||
})
|
||||
default:
|
||||
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
|
||||
}
|
||||
if h.tracer != nil && h.ptoCount != 0 {
|
||||
h.tracer.UpdatedPTOCount(0)
|
||||
}
|
||||
h.ptoCount = 0
|
||||
h.numProbesToSend = 0
|
||||
h.ptoMode = SendNone
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
|
||||
wasAmplificationLimit := h.isAmplificationLimited()
|
||||
h.bytesReceived += n
|
||||
if wasAmplificationLimit && !h.isAmplificationLimited() {
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) {
|
||||
if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
|
||||
h.peerAddressValidated = true
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) packetsInFlight() int {
|
||||
packetsInFlight := h.appDataPackets.history.Len()
|
||||
if h.handshakePackets != nil {
|
||||
packetsInFlight += h.handshakePackets.history.Len()
|
||||
}
|
||||
if h.initialPackets != nil {
|
||||
packetsInFlight += h.initialPackets.history.Len()
|
||||
}
|
||||
return packetsInFlight
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacket(p *Packet) {
|
||||
h.bytesSent += p.Length
|
||||
// For the client, drop the Initial packet number space when the first Handshake packet is sent.
|
||||
if h.perspective == protocol.PerspectiveClient && p.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
|
||||
h.dropPackets(protocol.EncryptionInitial)
|
||||
}
|
||||
isAckEliciting := h.sentPacketImpl(p)
|
||||
if isAckEliciting {
|
||||
h.getPacketNumberSpace(p.EncryptionLevel).history.SentAckElicitingPacket(p)
|
||||
} else {
|
||||
h.getPacketNumberSpace(p.EncryptionLevel).history.SentNonAckElicitingPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime)
|
||||
putPacket(p)
|
||||
p = nil //nolint:ineffassign // This is just to be on the safe side.
|
||||
}
|
||||
if h.tracer != nil && isAckEliciting {
|
||||
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
|
||||
}
|
||||
if isAckEliciting || !h.peerCompletedAddressValidation {
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return h.initialPackets
|
||||
case protocol.EncryptionHandshake:
|
||||
return h.handshakePackets
|
||||
case protocol.Encryption0RTT, protocol.Encryption1RTT:
|
||||
return h.appDataPackets
|
||||
default:
|
||||
panic("invalid packet number space")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ {
|
||||
pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel)
|
||||
|
||||
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
|
||||
for p := utils.Max(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ {
|
||||
h.logger.Debugf("Skipping packet number %d", p)
|
||||
}
|
||||
}
|
||||
|
||||
pnSpace.largestSent = packet.PacketNumber
|
||||
isAckEliciting := len(packet.Frames) > 0
|
||||
|
||||
if isAckEliciting {
|
||||
pnSpace.lastAckElicitingPacketTime = packet.SendTime
|
||||
packet.includedInBytesInFlight = true
|
||||
h.bytesInFlight += packet.Length
|
||||
if h.numProbesToSend > 0 {
|
||||
h.numProbesToSend--
|
||||
}
|
||||
}
|
||||
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting)
|
||||
|
||||
return isAckEliciting
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
|
||||
largestAcked := ack.LargestAcked()
|
||||
if largestAcked > pnSpace.largestSent {
|
||||
return false, &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "received ACK for an unsent packet",
|
||||
}
|
||||
}
|
||||
|
||||
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
|
||||
|
||||
// Servers complete address validation when a protected packet is received.
|
||||
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
|
||||
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
|
||||
h.peerCompletedAddressValidation = true
|
||||
h.logger.Debugf("Peer doesn't await address validation any longer.")
|
||||
// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
|
||||
priorInFlight := h.bytesInFlight
|
||||
ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel)
|
||||
if err != nil || len(ackedPackets) == 0 {
|
||||
return false, err
|
||||
}
|
||||
// update the RTT, if the largest acked is newly acknowledged
|
||||
if len(ackedPackets) > 0 {
|
||||
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() {
|
||||
// don't use the ack delay for Initial and Handshake packets
|
||||
var ackDelay time.Duration
|
||||
if encLevel == protocol.Encryption1RTT {
|
||||
ackDelay = utils.Min(ack.DelayTime, h.rttStats.MaxAckDelay())
|
||||
}
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||
}
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
}
|
||||
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
|
||||
return false, err
|
||||
}
|
||||
var acked1RTTPacket bool
|
||||
for _, p := range ackedPackets {
|
||||
if p.includedInBytesInFlight && !p.declaredLost {
|
||||
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
|
||||
}
|
||||
if p.EncryptionLevel == protocol.Encryption1RTT {
|
||||
acked1RTTPacket = true
|
||||
}
|
||||
h.removeFromBytesInFlight(p)
|
||||
putPacket(p)
|
||||
}
|
||||
// After this point, we must not use ackedPackets any longer!
|
||||
// We've already returned the buffers.
|
||||
ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side.
|
||||
|
||||
// Reset the pto_count unless the client is unsure if the server has validated the client's address.
|
||||
if h.peerCompletedAddressValidation {
|
||||
if h.tracer != nil && h.ptoCount != 0 {
|
||||
h.tracer.UpdatedPTOCount(0)
|
||||
}
|
||||
h.ptoCount = 0
|
||||
}
|
||||
h.numProbesToSend = 0
|
||||
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
|
||||
}
|
||||
|
||||
pnSpace.history.DeleteOldPackets(rcvTime)
|
||||
h.setLossDetectionTimer()
|
||||
return acked1RTTPacket, nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||
return h.lowestNotConfirmedAcked
|
||||
}
|
||||
|
||||
// Packets are returned in ascending packet number order.
|
||||
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
h.ackedPackets = h.ackedPackets[:0]
|
||||
ackRangeIndex := 0
|
||||
lowestAcked := ack.LowestAcked()
|
||||
largestAcked := ack.LargestAcked()
|
||||
err := pnSpace.history.Iterate(func(p *Packet) (bool, error) {
|
||||
// Ignore packets below the lowest acked
|
||||
if p.PacketNumber < lowestAcked {
|
||||
return true, nil
|
||||
}
|
||||
// Break after largest acked is reached
|
||||
if p.PacketNumber > largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ack.HasMissingRanges() {
|
||||
ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
|
||||
return true, nil
|
||||
}
|
||||
if p.PacketNumber > ackRange.Largest {
|
||||
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||
}
|
||||
}
|
||||
if p.skippedPacket {
|
||||
return false, &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
|
||||
}
|
||||
}
|
||||
h.ackedPackets = append(h.ackedPackets, p)
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(h.ackedPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
|
||||
for i, p := range h.ackedPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns)
|
||||
}
|
||||
|
||||
for _, p := range h.ackedPackets {
|
||||
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
|
||||
h.lowestNotConfirmedAcked = utils.Max(h.lowestNotConfirmedAcked, p.LargestAcked+1)
|
||||
}
|
||||
|
||||
for _, f := range p.Frames {
|
||||
if f.OnAcked != nil {
|
||||
f.OnAcked(f.Frame)
|
||||
}
|
||||
}
|
||||
if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
return h.ackedPackets, err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) {
|
||||
var encLevel protocol.EncryptionLevel
|
||||
var lossTime time.Time
|
||||
|
||||
if h.initialPackets != nil {
|
||||
lossTime = h.initialPackets.lossTime
|
||||
encLevel = protocol.EncryptionInitial
|
||||
}
|
||||
if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) {
|
||||
lossTime = h.handshakePackets.lossTime
|
||||
encLevel = protocol.EncryptionHandshake
|
||||
}
|
||||
if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) {
|
||||
lossTime = h.appDataPackets.lossTime
|
||||
encLevel = protocol.Encryption1RTT
|
||||
}
|
||||
return lossTime, encLevel
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration {
|
||||
pto := h.rttStats.PTO(includeMaxAckDelay) << h.ptoCount
|
||||
if pto > maxPTODuration || pto <= 0 {
|
||||
return maxPTODuration
|
||||
}
|
||||
return pto
|
||||
}
|
||||
|
||||
// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
|
||||
func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
|
||||
// We only send application data probe packets once the handshake is confirmed,
|
||||
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
|
||||
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
|
||||
if h.peerCompletedAddressValidation {
|
||||
return
|
||||
}
|
||||
t := time.Now().Add(h.getScaledPTO(false))
|
||||
if h.initialPackets != nil {
|
||||
return t, protocol.EncryptionInitial, true
|
||||
}
|
||||
return t, protocol.EncryptionHandshake, true
|
||||
}
|
||||
|
||||
if h.initialPackets != nil {
|
||||
encLevel = protocol.EncryptionInitial
|
||||
if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() {
|
||||
pto = t.Add(h.getScaledPTO(false))
|
||||
}
|
||||
}
|
||||
if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
|
||||
t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false))
|
||||
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
|
||||
pto = t
|
||||
encLevel = protocol.EncryptionHandshake
|
||||
}
|
||||
}
|
||||
if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
|
||||
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true))
|
||||
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
|
||||
pto = t
|
||||
encLevel = protocol.Encryption1RTT
|
||||
}
|
||||
}
|
||||
return pto, encLevel, true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
|
||||
if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() {
|
||||
return true
|
||||
}
|
||||
if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) hasOutstandingPackets() bool {
|
||||
return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) setLossDetectionTimer() {
|
||||
oldAlarm := h.alarm // only needed in case tracing is enabled
|
||||
lossTime, encLevel := h.getLossTimeAndSpace()
|
||||
if !lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = lossTime
|
||||
if h.tracer != nil && h.alarm != oldAlarm {
|
||||
h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel the alarm if amplification limited.
|
||||
if h.isAmplificationLimited() {
|
||||
h.alarm = time.Time{}
|
||||
if !oldAlarm.IsZero() {
|
||||
h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
|
||||
if h.tracer != nil {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation {
|
||||
h.alarm = time.Time{}
|
||||
if !oldAlarm.IsZero() {
|
||||
h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
|
||||
if h.tracer != nil {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PTO alarm
|
||||
ptoTime, encLevel, ok := h.getPTOTimeAndSpace()
|
||||
if !ok {
|
||||
if !oldAlarm.IsZero() {
|
||||
h.alarm = time.Time{}
|
||||
h.logger.Debugf("Canceling loss detection timer. No PTO needed..")
|
||||
if h.tracer != nil {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
h.alarm = ptoTime
|
||||
if h.tracer != nil && h.alarm != oldAlarm {
|
||||
h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
pnSpace.lossTime = time.Time{}
|
||||
|
||||
maxRTT := float64(utils.Max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
lossDelay := time.Duration(timeThreshold * maxRTT)
|
||||
|
||||
// Minimum time of granularity before packets are deemed lost.
|
||||
lossDelay = utils.Max(lossDelay, protocol.TimerGranularity)
|
||||
|
||||
// Packets sent before this time are deemed lost.
|
||||
lostSendTime := now.Add(-lossDelay)
|
||||
|
||||
priorInFlight := h.bytesInFlight
|
||||
return pnSpace.history.Iterate(func(p *Packet) (bool, error) {
|
||||
if p.PacketNumber > pnSpace.largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
if p.declaredLost || p.skippedPacket {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var packetLost bool
|
||||
if p.SendTime.Before(lostSendTime) {
|
||||
packetLost = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
|
||||
}
|
||||
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
|
||||
packetLost = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
|
||||
}
|
||||
} else if pnSpace.lossTime.IsZero() {
|
||||
// Note: This conditional is only entered once per call
|
||||
lossTime := p.SendTime.Add(lossDelay)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime)
|
||||
}
|
||||
pnSpace.lossTime = lossTime
|
||||
}
|
||||
if packetLost {
|
||||
p = pnSpace.history.DeclareLost(p)
|
||||
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
|
||||
h.removeFromBytesInFlight(p)
|
||||
h.queueFramesForRetransmission(p)
|
||||
if !p.IsPathMTUProbePacket {
|
||||
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) OnLossDetectionTimeout() error {
|
||||
defer h.setLossDetectionTimer()
|
||||
earliestLossTime, encLevel := h.getLossTimeAndSpace()
|
||||
if !earliestLossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
|
||||
}
|
||||
// Early retransmit or time loss detection
|
||||
return h.detectLostPackets(time.Now(), encLevel)
|
||||
}
|
||||
|
||||
// PTO
|
||||
// When all outstanding are acknowledged, the alarm is canceled in
|
||||
// setLossDetectionTimer. This doesn't reset the timer in the session though.
|
||||
// When OnAlarm is called, we therefore need to make sure that there are
|
||||
// actually packets outstanding.
|
||||
if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation {
|
||||
h.ptoCount++
|
||||
h.numProbesToSend++
|
||||
if h.initialPackets != nil {
|
||||
h.ptoMode = SendPTOInitial
|
||||
} else if h.handshakePackets != nil {
|
||||
h.ptoMode = SendPTOHandshake
|
||||
} else {
|
||||
return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, encLevel, ok := h.getPTOTimeAndSpace()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation {
|
||||
return nil
|
||||
}
|
||||
h.ptoCount++
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount)
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel)
|
||||
h.tracer.UpdatedPTOCount(h.ptoCount)
|
||||
}
|
||||
h.numProbesToSend += 2
|
||||
//nolint:exhaustive // We never arm a PTO timer for 0-RTT packets.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.ptoMode = SendPTOInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
h.ptoMode = SendPTOHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
// skip a packet number in order to elicit an immediate ACK
|
||||
_ = h.PopPacketNumber(protocol.Encryption1RTT)
|
||||
h.ptoMode = SendPTOAppData
|
||||
default:
|
||||
return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
|
||||
return h.alarm
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
|
||||
var lowestUnacked protocol.PacketNumber
|
||||
if p := pnSpace.history.FirstOutstanding(); p != nil {
|
||||
lowestUnacked = p.PacketNumber
|
||||
} else {
|
||||
lowestUnacked = pnSpace.largestAcked + 1
|
||||
}
|
||||
|
||||
pn := pnSpace.pns.Peek()
|
||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
|
||||
return h.getPacketNumberSpace(encLevel).pns.Pop()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendMode() SendMode {
|
||||
numTrackedPackets := h.appDataPackets.history.Len()
|
||||
if h.initialPackets != nil {
|
||||
numTrackedPackets += h.initialPackets.history.Len()
|
||||
}
|
||||
if h.handshakePackets != nil {
|
||||
numTrackedPackets += h.handshakePackets.history.Len()
|
||||
}
|
||||
|
||||
if h.isAmplificationLimited() {
|
||||
h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent)
|
||||
return SendNone
|
||||
}
|
||||
// Don't send any packets if we're keeping track of the maximum number of packets.
|
||||
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
|
||||
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
|
||||
// but still allow sending of retransmissions and ACKs.
|
||||
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||
}
|
||||
return SendNone
|
||||
}
|
||||
if h.numProbesToSend > 0 {
|
||||
return h.ptoMode
|
||||
}
|
||||
// Only send ACKs if we're congestion limited.
|
||||
if !h.congestion.CanSend(h.bytesInFlight) {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow())
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
return SendAny
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||
return h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) HasPacingBudget() bool {
|
||||
return h.congestion.HasPacingBudget()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
h.congestion.SetMaxDatagramSize(s)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) isAmplificationLimited() bool {
|
||||
if h.peerAddressValidated {
|
||||
return false
|
||||
}
|
||||
return h.bytesSent >= amplificationFactor*h.bytesReceived
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
p := pnSpace.history.FirstOutstanding()
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
h.queueFramesForRetransmission(p)
|
||||
// TODO: don't declare the packet lost here.
|
||||
// Keep track of acknowledged frames instead.
|
||||
h.removeFromBytesInFlight(p)
|
||||
pnSpace.history.DeclareLost(p)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) {
|
||||
if len(p.Frames) == 0 {
|
||||
panic("no frames")
|
||||
}
|
||||
for _, f := range p.Frames {
|
||||
f.OnLost(f.Frame)
|
||||
}
|
||||
p.Frames = nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ResetForRetry() error {
|
||||
h.bytesInFlight = 0
|
||||
var firstPacketSendTime time.Time
|
||||
h.initialPackets.history.Iterate(func(p *Packet) (bool, error) {
|
||||
if firstPacketSendTime.IsZero() {
|
||||
firstPacketSendTime = p.SendTime
|
||||
}
|
||||
if p.declaredLost || p.skippedPacket {
|
||||
return true, nil
|
||||
}
|
||||
h.queueFramesForRetransmission(p)
|
||||
return true, nil
|
||||
})
|
||||
// All application data packets sent at this point are 0-RTT packets.
|
||||
// In the case of a Retry, we can assume that the server dropped all of them.
|
||||
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) {
|
||||
if !p.declaredLost && !p.skippedPacket {
|
||||
h.queueFramesForRetransmission(p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
|
||||
// Otherwise, we don't know which Initial the Retry was sent in response to.
|
||||
if h.ptoCount == 0 {
|
||||
// Don't set the RTT to a value lower than 5ms here.
|
||||
now := time.Now()
|
||||
h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
|
||||
}
|
||||
}
|
||||
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats)
|
||||
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats)
|
||||
oldAlarm := h.alarm
|
||||
h.alarm = time.Time{}
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedPTOCount(0)
|
||||
if !oldAlarm.IsZero() {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
h.ptoCount = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeConfirmed() {
|
||||
h.handshakeConfirmed = true
|
||||
// We don't send PTOs for application data packets before the handshake completes.
|
||||
// Make sure the timer is armed now, if necessary.
|
||||
h.setLossDetectionTimer()
|
||||
}
|
163
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
163
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
|
@ -0,0 +1,163 @@
|
|||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
)
|
||||
|
||||
type sentPacketHistory struct {
|
||||
rttStats *utils.RTTStats
|
||||
outstandingPacketList *list.List[*Packet]
|
||||
etcPacketList *list.List[*Packet]
|
||||
packetMap map[protocol.PacketNumber]*list.Element[*Packet]
|
||||
highestSent protocol.PacketNumber
|
||||
}
|
||||
|
||||
var packetElementPool sync.Pool
|
||||
|
||||
func init() {
|
||||
packetElementPool = *list.NewPool[*Packet]()
|
||||
}
|
||||
|
||||
func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory {
|
||||
return &sentPacketHistory{
|
||||
rttStats: rttStats,
|
||||
outstandingPacketList: list.NewWithPool[*Packet](&packetElementPool),
|
||||
etcPacketList: list.NewWithPool[*Packet](&packetElementPool),
|
||||
packetMap: make(map[protocol.PacketNumber]*list.Element[*Packet]),
|
||||
highestSent: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) {
|
||||
h.registerSentPacket(pn, encLevel, t)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentAckElicitingPacket(p *Packet) {
|
||||
h.registerSentPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime)
|
||||
|
||||
var el *list.Element[*Packet]
|
||||
if p.outstanding() {
|
||||
el = h.outstandingPacketList.PushBack(p)
|
||||
} else {
|
||||
el = h.etcPacketList.PushBack(p)
|
||||
}
|
||||
h.packetMap[p.PacketNumber] = el
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) registerSentPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) {
|
||||
if pn <= h.highestSent {
|
||||
panic("non-sequential packet number use")
|
||||
}
|
||||
// Skipped packet numbers.
|
||||
for p := h.highestSent + 1; p < pn; p++ {
|
||||
el := h.etcPacketList.PushBack(&Packet{
|
||||
PacketNumber: p,
|
||||
EncryptionLevel: encLevel,
|
||||
SendTime: t,
|
||||
skippedPacket: true,
|
||||
})
|
||||
h.packetMap[p] = el
|
||||
}
|
||||
h.highestSent = pn
|
||||
}
|
||||
|
||||
// Iterate iterates through all packets.
|
||||
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
|
||||
cont := true
|
||||
outstandingEl := h.outstandingPacketList.Front()
|
||||
etcEl := h.etcPacketList.Front()
|
||||
var el *list.Element[*Packet]
|
||||
// whichever has the next packet number is returned first
|
||||
for cont {
|
||||
if outstandingEl == nil || (etcEl != nil && etcEl.Value.PacketNumber < outstandingEl.Value.PacketNumber) {
|
||||
el = etcEl
|
||||
} else {
|
||||
el = outstandingEl
|
||||
}
|
||||
if el == nil {
|
||||
return nil
|
||||
}
|
||||
if el == outstandingEl {
|
||||
outstandingEl = outstandingEl.Next()
|
||||
} else {
|
||||
etcEl = etcEl.Next()
|
||||
}
|
||||
var err error
|
||||
cont, err = cb(el.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FirstOutstanding returns the first outstanding packet.
|
||||
func (h *sentPacketHistory) FirstOutstanding() *Packet {
|
||||
el := h.outstandingPacketList.Front()
|
||||
if el == nil {
|
||||
return nil
|
||||
}
|
||||
return el.Value
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Len() int {
|
||||
return len(h.packetMap)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
|
||||
el, ok := h.packetMap[p]
|
||||
if !ok {
|
||||
return fmt.Errorf("packet %d not found in sent packet history", p)
|
||||
}
|
||||
el.List().Remove(el)
|
||||
delete(h.packetMap, p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingPackets() bool {
|
||||
return h.outstandingPacketList.Len() > 0
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) DeleteOldPackets(now time.Time) {
|
||||
maxAge := 3 * h.rttStats.PTO(false)
|
||||
var nextEl *list.Element[*Packet]
|
||||
// we don't iterate outstandingPacketList, as we should not delete outstanding packets.
|
||||
// being outstanding for more than 3*PTO should only happen in the case of drastic RTT changes.
|
||||
for el := h.etcPacketList.Front(); el != nil; el = nextEl {
|
||||
nextEl = el.Next()
|
||||
p := el.Value
|
||||
if p.SendTime.After(now.Add(-maxAge)) {
|
||||
break
|
||||
}
|
||||
delete(h.packetMap, p.PacketNumber)
|
||||
h.etcPacketList.Remove(el)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet {
|
||||
el, ok := h.packetMap[p.PacketNumber]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
el.List().Remove(el)
|
||||
p.declaredLost = true
|
||||
// move it to the correct position in the etc list (based on the packet number)
|
||||
for el = h.etcPacketList.Back(); el != nil; el = el.Prev() {
|
||||
if el.Value.PacketNumber < p.PacketNumber {
|
||||
break
|
||||
}
|
||||
}
|
||||
if el == nil {
|
||||
el = h.etcPacketList.PushFront(p)
|
||||
} else {
|
||||
el = h.etcPacketList.InsertAfter(p, el)
|
||||
}
|
||||
h.packetMap[p.PacketNumber] = el
|
||||
return el.Value
|
||||
}
|
25
vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go
generated
vendored
Normal file
25
vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go
generated
vendored
Normal file
|
@ -0,0 +1,25 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Bandwidth of a connection
|
||||
type Bandwidth uint64
|
||||
|
||||
const infBandwidth Bandwidth = math.MaxUint64
|
||||
|
||||
const (
|
||||
// BitsPerSecond is 1 bit per second
|
||||
BitsPerSecond Bandwidth = 1
|
||||
// BytesPerSecond is 1 byte per second
|
||||
BytesPerSecond = 8 * BitsPerSecond
|
||||
)
|
||||
|
||||
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
||||
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
|
||||
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
||||
}
|
18
vendor/github.com/quic-go/quic-go/internal/congestion/clock.go
generated
vendored
Normal file
18
vendor/github.com/quic-go/quic-go/internal/congestion/clock.go
generated
vendored
Normal file
|
@ -0,0 +1,18 @@
|
|||
package congestion
|
||||
|
||||
import "time"
|
||||
|
||||
// A Clock returns the current time
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
||||
type DefaultClock struct{}
|
||||
|
||||
var _ Clock = DefaultClock{}
|
||||
|
||||
// Now gets the current time
|
||||
func (DefaultClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
214
vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go
generated
vendored
Normal file
214
vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go
generated
vendored
Normal file
|
@ -0,0 +1,214 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
||||
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
|
||||
|
||||
// Constants based on TCP defaults.
|
||||
// The following constants are in 2^10 fractions of a second instead of ms to
|
||||
// allow a 10 shift right to divide.
|
||||
|
||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const (
|
||||
cubeScale = 40
|
||||
cubeCongestionWindowScale = 410
|
||||
cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
||||
// TODO: when re-enabling cubic, make sure to use the actual packet size here
|
||||
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
|
||||
)
|
||||
|
||||
const defaultNumConnections = 1
|
||||
|
||||
// Default Cubic backoff factor
|
||||
const beta float32 = 0.7
|
||||
|
||||
// Additional backoff factor when loss occurs in the concave part of the Cubic
|
||||
// curve. This additional backoff factor is expected to give up bandwidth to
|
||||
// new concurrent flows and speed up convergence.
|
||||
const betaLastMax float32 = 0.85
|
||||
|
||||
// Cubic implements the cubic algorithm from TCP
|
||||
type Cubic struct {
|
||||
clock Clock
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// Time when this cycle started, after last loss event.
|
||||
epoch time.Time
|
||||
|
||||
// Max congestion window used just before last loss event.
|
||||
// Note: to improve fairness to other streams an additional back off is
|
||||
// applied to this value if the new value is below our latest value.
|
||||
lastMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Number of acked bytes since the cycle started (epoch).
|
||||
ackedBytesCount protocol.ByteCount
|
||||
|
||||
// TCP Reno equivalent congestion window in packets.
|
||||
estimatedTCPcongestionWindow protocol.ByteCount
|
||||
|
||||
// Origin point of cubic function.
|
||||
originPointCongestionWindow protocol.ByteCount
|
||||
|
||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||
timeToOriginPoint uint32
|
||||
|
||||
// Last congestion window in packets computed by cubic function.
|
||||
lastTargetCongestionWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
// NewCubic returns a new Cubic instance
|
||||
func NewCubic(clock Clock) *Cubic {
|
||||
c := &Cubic{
|
||||
clock: clock,
|
||||
numConnections: defaultNumConnections,
|
||||
}
|
||||
c.Reset()
|
||||
return c
|
||||
}
|
||||
|
||||
// Reset is called after a timeout to reset the cubic state
|
||||
func (c *Cubic) Reset() {
|
||||
c.epoch = time.Time{}
|
||||
c.lastMaxCongestionWindow = 0
|
||||
c.ackedBytesCount = 0
|
||||
c.estimatedTCPcongestionWindow = 0
|
||||
c.originPointCongestionWindow = 0
|
||||
c.timeToOriginPoint = 0
|
||||
c.lastTargetCongestionWindow = 0
|
||||
}
|
||||
|
||||
func (c *Cubic) alpha() float32 {
|
||||
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
|
||||
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
|
||||
// We derive the equivalent alpha for an N-connection emulation as:
|
||||
b := c.beta()
|
||||
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
|
||||
}
|
||||
|
||||
func (c *Cubic) beta() float32 {
|
||||
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
||||
// emulation, which emulates the effective backoff of an ensemble of N
|
||||
// TCP-Reno connections on a single loss event. The effective multiplier is
|
||||
// computed as:
|
||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
func (c *Cubic) betaLastMax() float32 {
|
||||
// betaLastMax is the additional backoff factor after loss for our
|
||||
// N-connection emulation, which emulates the additional backoff of
|
||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||
// effective multiplier is computed as:
|
||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||
// the available congestion window. Resets Cubic state during quiescence.
|
||||
func (c *Cubic) OnApplicationLimited() {
|
||||
// When sender is not using the available congestion window, the window does
|
||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||
// using the entire window during the time since the beginning of the current
|
||||
// "epoch" (the end of the last loss recovery period). Since
|
||||
// application-limited periods break this assumption, we reset the epoch when
|
||||
// in such a period. This reset effectively freezes congestion window growth
|
||||
// through application-limited periods and allows Cubic growth to continue
|
||||
// when the entire window is being used.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
|
||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||
// a loss event. Returns the new congestion window in packets. The new
|
||||
// congestion window is a multiplicative decrease of our current window.
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
|
||||
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
|
||||
// We never reached the old max, so assume we are competing with another
|
||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||
} else {
|
||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||
}
|
||||
c.epoch = time.Time{} // Reset time.
|
||||
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||
}
|
||||
|
||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||
// Returns the new congestion window in packets. The new congestion window
|
||||
// follows a cubic function that depends on the time passed since last
|
||||
// packet loss.
|
||||
func (c *Cubic) CongestionWindowAfterAck(
|
||||
ackedBytes protocol.ByteCount,
|
||||
currentCongestionWindow protocol.ByteCount,
|
||||
delayMin time.Duration,
|
||||
eventTime time.Time,
|
||||
) protocol.ByteCount {
|
||||
c.ackedBytesCount += ackedBytes
|
||||
|
||||
if c.epoch.IsZero() {
|
||||
// First ACK after a loss event.
|
||||
c.epoch = eventTime // Start of epoch.
|
||||
c.ackedBytesCount = ackedBytes // Reset count.
|
||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||
c.timeToOriginPoint = 0
|
||||
c.originPointCongestionWindow = currentCongestionWindow
|
||||
} else {
|
||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||
}
|
||||
}
|
||||
|
||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||
// the round trip time in account. This is done to allow us to use shift as a
|
||||
// divide operator.
|
||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||
|
||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
|
||||
var targetCongestionWindow protocol.ByteCount
|
||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||
} else {
|
||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||
}
|
||||
// Limit the CWND increase to half the acked bytes.
|
||||
targetCongestionWindow = utils.Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||
|
||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||
// time we ack an estimated tcp window of bytes. For small
|
||||
// congestion windows (less than 25), the formula below will
|
||||
// increase slightly slower than linearly per estimated tcp window
|
||||
// of bytes.
|
||||
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
|
||||
c.ackedBytesCount = 0
|
||||
|
||||
// We have a new cubic congestion window.
|
||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||
|
||||
// Compute target congestion_window based on cubic target and estimated TCP
|
||||
// congestion_window, use highest (fastest).
|
||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||
}
|
||||
return targetCongestionWindow
|
||||
}
|
||||
|
||||
// SetNumConnections sets the number of emulated connections
|
||||
func (c *Cubic) SetNumConnections(n int) {
|
||||
c.numConnections = n
|
||||
}
|
316
vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go
generated
vendored
Normal file
316
vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go
generated
vendored
Normal file
|
@ -0,0 +1,316 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
|
||||
// Used in QUIC for congestion window computations in bytes.
|
||||
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
|
||||
maxBurstPackets = 3
|
||||
renoBeta = 0.7 // Reno backoff factor.
|
||||
minCongestionWindowPackets = 2
|
||||
initialCongestionWindow = 32
|
||||
)
|
||||
|
||||
type cubicSender struct {
|
||||
hybridSlowStart HybridSlowStart
|
||||
rttStats *utils.RTTStats
|
||||
cubic *Cubic
|
||||
pacer *pacer
|
||||
clock Clock
|
||||
|
||||
reno bool
|
||||
|
||||
// Track the largest packet that has been sent.
|
||||
largestSentPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet that has been acked.
|
||||
largestAckedPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||
largestSentAtLastCutback protocol.PacketNumber
|
||||
|
||||
// Whether the last loss event caused us to exit slowstart.
|
||||
// Used for stats collection of slowstartPacketsLost
|
||||
lastCutbackExitedSlowstart bool
|
||||
|
||||
// Congestion window in bytes.
|
||||
congestionWindow protocol.ByteCount
|
||||
|
||||
// Slow start congestion window in bytes, aka ssthresh.
|
||||
slowStartThreshold protocol.ByteCount
|
||||
|
||||
// ACK counter for the Reno implementation.
|
||||
numAckedPackets uint64
|
||||
|
||||
initialCongestionWindow protocol.ByteCount
|
||||
initialMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
maxDatagramSize protocol.ByteCount
|
||||
|
||||
lastState logging.CongestionState
|
||||
tracer logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var (
|
||||
_ SendAlgorithm = &cubicSender{}
|
||||
_ SendAlgorithmWithDebugInfos = &cubicSender{}
|
||||
)
|
||||
|
||||
// NewCubicSender makes a new cubic sender
|
||||
func NewCubicSender(
|
||||
clock Clock,
|
||||
rttStats *utils.RTTStats,
|
||||
initialMaxDatagramSize protocol.ByteCount,
|
||||
reno bool,
|
||||
tracer logging.ConnectionTracer,
|
||||
) *cubicSender {
|
||||
return newCubicSender(
|
||||
clock,
|
||||
rttStats,
|
||||
reno,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow*initialMaxDatagramSize,
|
||||
protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
|
||||
tracer,
|
||||
)
|
||||
}
|
||||
|
||||
func newCubicSender(
|
||||
clock Clock,
|
||||
rttStats *utils.RTTStats,
|
||||
reno bool,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow,
|
||||
initialMaxCongestionWindow protocol.ByteCount,
|
||||
tracer logging.ConnectionTracer,
|
||||
) *cubicSender {
|
||||
c := &cubicSender{
|
||||
rttStats: rttStats,
|
||||
largestSentPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestAckedPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestSentAtLastCutback: protocol.InvalidPacketNumber,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
||||
congestionWindow: initialCongestionWindow,
|
||||
slowStartThreshold: protocol.MaxByteCount,
|
||||
cubic: NewCubic(clock),
|
||||
clock: clock,
|
||||
reno: reno,
|
||||
tracer: tracer,
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
}
|
||||
c.pacer = newPacer(c.BandwidthEstimate)
|
||||
if c.tracer != nil {
|
||||
c.lastState = logging.CongestionStateSlowStart
|
||||
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
|
||||
return c.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (c *cubicSender) HasPacingBudget() bool {
|
||||
return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize
|
||||
}
|
||||
|
||||
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
|
||||
return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
|
||||
return c.maxDatagramSize * minCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
_ protocol.ByteCount,
|
||||
packetNumber protocol.PacketNumber,
|
||||
bytes protocol.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
c.pacer.SentPacket(sentTime, bytes)
|
||||
if !isRetransmittable {
|
||||
return
|
||||
}
|
||||
c.largestSentPacketNumber = packetNumber
|
||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||
}
|
||||
|
||||
func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
|
||||
return bytesInFlight < c.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (c *cubicSender) InRecovery() bool {
|
||||
return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
|
||||
}
|
||||
|
||||
func (c *cubicSender) InSlowStart() bool {
|
||||
return c.GetCongestionWindow() < c.slowStartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
||||
return c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) MaybeExitSlowStart() {
|
||||
if c.InSlowStart() &&
|
||||
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
|
||||
// exit slow start
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketAcked(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
c.largestAckedPacketNumber = utils.Max(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
return
|
||||
}
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||
if c.InSlowStart() {
|
||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
return
|
||||
}
|
||||
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
||||
c.maybeTraceStateChange(logging.CongestionStateRecovery)
|
||||
|
||||
if c.reno {
|
||||
c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
|
||||
} else {
|
||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||
}
|
||||
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
|
||||
c.congestionWindow = minCwnd
|
||||
}
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||
// reset packet count from congestion avoidance mode. We start
|
||||
// counting again when we're out of recovery.
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
|
||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||
// represents, but quic has a separate ack for each packet.
|
||||
func (c *cubicSender) maybeIncreaseCwnd(
|
||||
_ protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
// Do not increase the congestion window unless the sender is close to using
|
||||
// the current window.
|
||||
if !c.isCwndLimited(priorInFlight) {
|
||||
c.cubic.OnApplicationLimited()
|
||||
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
|
||||
return
|
||||
}
|
||||
if c.congestionWindow >= c.maxCongestionWindow() {
|
||||
return
|
||||
}
|
||||
if c.InSlowStart() {
|
||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
|
||||
return
|
||||
}
|
||||
// Congestion avoidance
|
||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||
if c.reno {
|
||||
// Classic Reno congestion avoidance.
|
||||
c.numAckedPackets++
|
||||
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
} else {
|
||||
c.congestionWindow = utils.Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
|
||||
congestionWindow := c.GetCongestionWindow()
|
||||
if bytesInFlight >= congestionWindow {
|
||||
return true
|
||||
}
|
||||
availableBytes := congestionWindow - bytesInFlight
|
||||
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
||||
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
|
||||
}
|
||||
|
||||
// BandwidthEstimate returns the current bandwidth estimate
|
||||
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
||||
srtt := c.rttStats.SmoothedRTT()
|
||||
if srtt == 0 {
|
||||
// If we haven't measured an rtt, the bandwidth estimate is unknown.
|
||||
return infBandwidth
|
||||
}
|
||||
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
||||
}
|
||||
|
||||
// OnRetransmissionTimeout is called on an retransmission timeout
|
||||
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
|
||||
if !packetsRetransmitted {
|
||||
return
|
||||
}
|
||||
c.hybridSlowStart.Restart()
|
||||
c.cubic.Reset()
|
||||
c.slowStartThreshold = c.congestionWindow / 2
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when the connection is migrated (?)
|
||||
func (c *cubicSender) OnConnectionMigration() {
|
||||
c.hybridSlowStart.Restart()
|
||||
c.largestSentPacketNumber = protocol.InvalidPacketNumber
|
||||
c.largestAckedPacketNumber = protocol.InvalidPacketNumber
|
||||
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
|
||||
c.lastCutbackExitedSlowstart = false
|
||||
c.cubic.Reset()
|
||||
c.numAckedPackets = 0
|
||||
c.congestionWindow = c.initialCongestionWindow
|
||||
c.slowStartThreshold = c.initialMaxCongestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
|
||||
if c.tracer == nil || new == c.lastState {
|
||||
return
|
||||
}
|
||||
c.tracer.UpdatedCongestionState(new)
|
||||
c.lastState = new
|
||||
}
|
||||
|
||||
func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
if s < c.maxDatagramSize {
|
||||
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
|
||||
}
|
||||
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
|
||||
c.maxDatagramSize = s
|
||||
if cwndIsMinCwnd {
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
c.pacer.SetMaxDatagramSize(s)
|
||||
}
|
113
vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go
generated
vendored
Normal file
113
vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go
generated
vendored
Normal file
|
@ -0,0 +1,113 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
||||
// tcp_cubic.c.
|
||||
const hybridStartLowWindow = protocol.ByteCount(16)
|
||||
|
||||
// Number of delay samples for detecting the increase of delay.
|
||||
const hybridStartMinSamples = uint32(8)
|
||||
|
||||
// Exit slow start if the min rtt has increased by more than 1/8th.
|
||||
const hybridStartDelayFactorExp = 3 // 2^3 = 8
|
||||
// The original paper specifies 2 and 8ms, but those have changed over time.
|
||||
const (
|
||||
hybridStartDelayMinThresholdUs = int64(4000)
|
||||
hybridStartDelayMaxThresholdUs = int64(16000)
|
||||
)
|
||||
|
||||
// HybridSlowStart implements the TCP hybrid slow start algorithm
|
||||
type HybridSlowStart struct {
|
||||
endPacketNumber protocol.PacketNumber
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
started bool
|
||||
currentMinRTT time.Duration
|
||||
rttSampleCount uint32
|
||||
hystartFound bool
|
||||
}
|
||||
|
||||
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
|
||||
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
|
||||
s.endPacketNumber = lastSent
|
||||
s.currentMinRTT = 0
|
||||
s.rttSampleCount = 0
|
||||
s.started = true
|
||||
}
|
||||
|
||||
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
|
||||
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
|
||||
return s.endPacketNumber < ack
|
||||
}
|
||||
|
||||
// ShouldExitSlowStart should be called on every new ack frame, since a new
|
||||
// RTT measurement can be made then.
|
||||
// rtt: the RTT for this ack packet.
|
||||
// minRTT: is the lowest delay (RTT) we have seen during the session.
|
||||
// congestionWindow: the congestion window in packets.
|
||||
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
|
||||
if !s.started {
|
||||
// Time to start the hybrid slow start.
|
||||
s.StartReceiveRound(s.lastSentPacketNumber)
|
||||
}
|
||||
if s.hystartFound {
|
||||
return true
|
||||
}
|
||||
// Second detection parameter - delay increase detection.
|
||||
// Compare the minimum delay (s.currentMinRTT) of the current
|
||||
// burst of packets relative to the minimum delay during the session.
|
||||
// Note: we only look at the first few(8) packets in each burst, since we
|
||||
// only want to compare the lowest RTT of the burst relative to previous
|
||||
// bursts.
|
||||
s.rttSampleCount++
|
||||
if s.rttSampleCount <= hybridStartMinSamples {
|
||||
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
|
||||
s.currentMinRTT = latestRTT
|
||||
}
|
||||
}
|
||||
// We only need to check this once per round.
|
||||
if s.rttSampleCount == hybridStartMinSamples {
|
||||
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
||||
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
||||
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
||||
minRTTincreaseThresholdUs = utils.Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
||||
minRTTincreaseThreshold := time.Duration(utils.Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
||||
|
||||
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
||||
s.hystartFound = true
|
||||
}
|
||||
}
|
||||
// Exit from slow start if the cwnd is greater than 16 and
|
||||
// increasing delay is found.
|
||||
return congestionWindow >= hybridStartLowWindow && s.hystartFound
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet was sent
|
||||
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
|
||||
s.lastSentPacketNumber = packetNumber
|
||||
}
|
||||
|
||||
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
|
||||
// the round when the final packet of the burst is received and start it on
|
||||
// the next incoming ack.
|
||||
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
|
||||
if s.IsEndOfRound(ackedPacketNumber) {
|
||||
s.started = false
|
||||
}
|
||||
}
|
||||
|
||||
// Started returns true if started
|
||||
func (s *HybridSlowStart) Started() bool {
|
||||
return s.started
|
||||
}
|
||||
|
||||
// Restart the slow start phase
|
||||
func (s *HybridSlowStart) Restart() {
|
||||
s.started = false
|
||||
s.hystartFound = false
|
||||
}
|
28
vendor/github.com/quic-go/quic-go/internal/congestion/interface.go
generated
vendored
Normal file
28
vendor/github.com/quic-go/quic-go/internal/congestion/interface.go
generated
vendored
Normal file
|
@ -0,0 +1,28 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// A SendAlgorithm performs congestion control
|
||||
type SendAlgorithm interface {
|
||||
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time
|
||||
HasPacingBudget() bool
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
|
||||
CanSend(bytesInFlight protocol.ByteCount) bool
|
||||
MaybeExitSlowStart()
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
|
||||
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
SetMaxDatagramSize(protocol.ByteCount)
|
||||
}
|
||||
|
||||
// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos
|
||||
type SendAlgorithmWithDebugInfos interface {
|
||||
SendAlgorithm
|
||||
InSlowStart() bool
|
||||
InRecovery() bool
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
}
|
77
vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go
generated
vendored
Normal file
77
vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go
generated
vendored
Normal file
|
@ -0,0 +1,77 @@
|
|||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const maxBurstSizePackets = 10
|
||||
|
||||
// The pacer implements a token bucket pacing algorithm.
|
||||
type pacer struct {
|
||||
budgetAtLastSent protocol.ByteCount
|
||||
maxDatagramSize protocol.ByteCount
|
||||
lastSentTime time.Time
|
||||
getAdjustedBandwidth func() uint64 // in bytes/s
|
||||
}
|
||||
|
||||
func newPacer(getBandwidth func() Bandwidth) *pacer {
|
||||
p := &pacer{
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
getAdjustedBandwidth: func() uint64 {
|
||||
// Bandwidth is in bits/s. We need the value in bytes/s.
|
||||
bw := uint64(getBandwidth() / BytesPerSecond)
|
||||
// Use a slightly higher value than the actual measured bandwidth.
|
||||
// RTT variations then won't result in under-utilization of the congestion window.
|
||||
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
|
||||
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
|
||||
return bw * 5 / 4
|
||||
},
|
||||
}
|
||||
p.budgetAtLastSent = p.maxBurstSize()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size > budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *pacer) Budget(now time.Time) protocol.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
return utils.Min(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *pacer) maxBurstSize() protocol.ByteCount {
|
||||
return utils.Max(
|
||||
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
|
||||
maxBurstSizePackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||
func (p *pacer) TimeUntilSend() time.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return time.Time{}
|
||||
}
|
||||
return p.lastSentTime.Add(utils.Max(
|
||||
protocol.MinPacingDelay,
|
||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
|
||||
))
|
||||
}
|
||||
|
||||
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
125
vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
125
vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
|
@ -0,0 +1,125 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type baseFlowController struct {
|
||||
// for sending data
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
lastBlockedAt protocol.ByteCount
|
||||
|
||||
// for receiving data
|
||||
//nolint:structcheck // The mutex is used both by the stream and the connection flow controller
|
||||
mutex sync.Mutex
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowSize protocol.ByteCount
|
||||
maxReceiveWindowSize protocol.ByteCount
|
||||
|
||||
allowWindowIncrease func(size protocol.ByteCount) bool
|
||||
|
||||
epochStartTime time.Time
|
||||
epochStartOffset protocol.ByteCount
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||
if offset > c.sendWindow {
|
||||
c.sendWindow = offset
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
|
||||
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
|
||||
if c.bytesSent > c.sendWindow {
|
||||
return 0
|
||||
}
|
||||
return c.sendWindow - c.bytesSent
|
||||
}
|
||||
|
||||
// needs to be called with locked mutex
|
||||
func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
|
||||
// pretend we sent a WindowUpdate when reading the first byte
|
||||
// this way auto-tuning of the window size already works for the first WindowUpdate
|
||||
if c.bytesRead == 0 {
|
||||
c.startNewAutoTuningEpoch(time.Now())
|
||||
}
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than the threshold was consumed
|
||||
return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold))
|
||||
}
|
||||
|
||||
// getWindowUpdate updates the receive window, if necessary
|
||||
// it returns the new offset
|
||||
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
||||
if !c.hasWindowUpdate() {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.maybeAdjustWindowSize()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||
return c.receiveWindow
|
||||
}
|
||||
|
||||
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||
func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||
// don't do anything if less than half the window has been consumed
|
||||
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||
return
|
||||
}
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||
now := time.Now()
|
||||
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||
// window is consumed too fast, try to increase the window size
|
||||
newSize := utils.Min(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||
if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) {
|
||||
c.receiveWindowSize = newSize
|
||||
}
|
||||
}
|
||||
c.startNewAutoTuningEpoch(now)
|
||||
}
|
||||
|
||||
func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) {
|
||||
c.epochStartTime = now
|
||||
c.epochStartOffset = c.bytesRead
|
||||
}
|
||||
|
||||
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||
return c.highestReceived > c.receiveWindow
|
||||
}
|
112
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
112
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
|
@ -0,0 +1,112 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
baseFlowController
|
||||
|
||||
queueWindowUpdate func()
|
||||
}
|
||||
|
||||
var _ ConnectionFlowController = &connectionFlowController{}
|
||||
|
||||
// NewConnectionFlowController gets a new flow controller for the connection
|
||||
// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0.
|
||||
func NewConnectionFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(),
|
||||
allowWindowIncrease func(size protocol.ByteCount) bool,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ConnectionFlowController {
|
||||
return &connectionFlowController{
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
allowWindowIncrease: allowWindowIncrease,
|
||||
logger: logger,
|
||||
},
|
||||
queueWindowUpdate: queueWindowUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||
return c.baseFlowController.sendWindowSize()
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.highestReceived += increment
|
||||
if c.checkFlowControlViolation() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FlowControlError,
|
||||
ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
c.baseFlowController.addBytesRead(n)
|
||||
shouldQueueWindowUpdate := c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if shouldQueueWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if oldWindowSize < c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowSize sets a minimum window size
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
if inc > c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
|
||||
newSize := utils.Min(inc, c.maxReceiveWindowSize)
|
||||
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
|
||||
c.receiveWindowSize = newSize
|
||||
}
|
||||
c.startNewAutoTuningEpoch(time.Now())
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Reset rests the flow controller. This happens when 0-RTT is rejected.
|
||||
// All stream data is invalidated, it's if we had never opened a stream and never sent any data.
|
||||
// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet.
|
||||
func (c *connectionFlowController) Reset() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() {
|
||||
return errors.New("flow controller reset after reading data")
|
||||
}
|
||||
c.bytesSent = 0
|
||||
c.lastBlockedAt = 0
|
||||
return nil
|
||||
}
|
42
vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
42
vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
|
@ -0,0 +1,42 @@
|
|||
package flowcontrol
|
||||
|
||||
import "github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
type flowController interface {
|
||||
// for sending
|
||||
SendWindowSize() protocol.ByteCount
|
||||
UpdateSendWindow(protocol.ByteCount)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream,
|
||||
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||
// Abandon should be called when reading from the stream is aborted early,
|
||||
// and there won't be any further calls to AddBytesRead.
|
||||
Abandon()
|
||||
}
|
||||
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
Reset() error
|
||||
}
|
||||
|
||||
type connectionFlowControllerI interface {
|
||||
ConnectionFlowController
|
||||
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||
// for sending
|
||||
EnsureMinimumWindowSize(protocol.ByteCount)
|
||||
// for receiving
|
||||
IncrementHighestReceived(protocol.ByteCount) error
|
||||
}
|
149
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
149
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
|
@ -0,0 +1,149 @@
|
|||
package flowcontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type streamFlowController struct {
|
||||
baseFlowController
|
||||
|
||||
streamID protocol.StreamID
|
||||
|
||||
queueWindowUpdate func()
|
||||
|
||||
connection connectionFlowControllerI
|
||||
|
||||
receivedFinalOffset bool
|
||||
}
|
||||
|
||||
var _ StreamFlowController = &streamFlowController{}
|
||||
|
||||
// NewStreamFlowController gets a new flow controller for a stream
|
||||
func NewStreamFlowController(
|
||||
streamID protocol.StreamID,
|
||||
cfc ConnectionFlowController,
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(protocol.StreamID),
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highestReceived value, if the offset is higher.
|
||||
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error {
|
||||
// If the final offset for this stream is already known, check for consistency.
|
||||
if c.receivedFinalOffset {
|
||||
// If we receive another final offset, check that it's the same.
|
||||
if final && offset != c.highestReceived {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset),
|
||||
}
|
||||
}
|
||||
// Check that the offset is below the final offset.
|
||||
if offset > c.highestReceived {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if final {
|
||||
c.receivedFinalOffset = true
|
||||
}
|
||||
if offset == c.highestReceived {
|
||||
return nil
|
||||
}
|
||||
// A higher offset was received before.
|
||||
// This can happen due to reordering.
|
||||
if offset <= c.highestReceived {
|
||||
if final {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
increment := offset - c.highestReceived
|
||||
c.highestReceived = offset
|
||||
if c.checkFlowControlViolation() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FlowControlError,
|
||||
ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
|
||||
}
|
||||
}
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
c.baseFlowController.addBytesRead(n)
|
||||
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if shouldQueueWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
c.connection.AddBytesRead(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) Abandon() {
|
||||
c.mutex.Lock()
|
||||
unread := c.highestReceived - c.bytesRead
|
||||
c.mutex.Unlock()
|
||||
if unread > 0 {
|
||||
c.connection.AddBytesRead(unread)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesSent(n)
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
return utils.Min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
|
||||
}
|
||||
|
||||
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
|
||||
return !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
|
||||
if c.receivedFinalOffset {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
|
||||
c.mutex.Lock()
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
161
vendor/github.com/quic-go/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
161
vendor/github.com/quic-go/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
|
@ -0,0 +1,161 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
keyLabel = hkdfLabelKeyV2
|
||||
ivLabel = hkdfLabelIVV2
|
||||
}
|
||||
key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen)
|
||||
iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen())
|
||||
return suite.AEAD(key, iv)
|
||||
}
|
||||
|
||||
type longHeaderSealer struct {
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderSealer = &longHeaderSealer{}
|
||||
|
||||
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||
return &longHeaderSealer{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return s.aead.Seal(dst, s.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) Overhead() int {
|
||||
return s.aead.Overhead()
|
||||
}
|
||||
|
||||
type longHeaderOpener struct {
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||
|
||||
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||
return &longHeaderOpener{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
|
||||
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||
if err == nil {
|
||||
o.highestRcvdPN = utils.Max(o.highestRcvdPN, pn)
|
||||
} else {
|
||||
err = ErrDecryptionFailed
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
type handshakeSealer struct {
|
||||
LongHeaderSealer
|
||||
|
||||
dropInitialKeys func()
|
||||
dropped bool
|
||||
}
|
||||
|
||||
func newHandshakeSealer(
|
||||
aead cipher.AEAD,
|
||||
headerProtector headerProtector,
|
||||
dropInitialKeys func(),
|
||||
perspective protocol.Perspective,
|
||||
) LongHeaderSealer {
|
||||
sealer := newLongHeaderSealer(aead, headerProtector)
|
||||
// The client drops Initial keys when sending the first Handshake packet.
|
||||
if perspective == protocol.PerspectiveServer {
|
||||
return sealer
|
||||
}
|
||||
return &handshakeSealer{
|
||||
LongHeaderSealer: sealer,
|
||||
dropInitialKeys: dropInitialKeys,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
data := s.LongHeaderSealer.Seal(dst, src, pn, ad)
|
||||
if !s.dropped {
|
||||
s.dropInitialKeys()
|
||||
s.dropped = true
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
type handshakeOpener struct {
|
||||
LongHeaderOpener
|
||||
|
||||
dropInitialKeys func()
|
||||
dropped bool
|
||||
}
|
||||
|
||||
func newHandshakeOpener(
|
||||
aead cipher.AEAD,
|
||||
headerProtector headerProtector,
|
||||
dropInitialKeys func(),
|
||||
perspective protocol.Perspective,
|
||||
) LongHeaderOpener {
|
||||
opener := newLongHeaderOpener(aead, headerProtector)
|
||||
// The server drops Initial keys when first successfully processing a Handshake packet.
|
||||
if perspective == protocol.PerspectiveClient {
|
||||
return opener
|
||||
}
|
||||
return &handshakeOpener{
|
||||
LongHeaderOpener: opener,
|
||||
dropInitialKeys: dropInitialKeys,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad)
|
||||
if err == nil && !o.dropped {
|
||||
o.dropInitialKeys()
|
||||
o.dropped = true
|
||||
}
|
||||
return dec, err
|
||||
}
|
837
vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
837
vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
|
@ -0,0 +1,837 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// TLS unexpected_message alert
|
||||
const alertUnexpectedMessage uint8 = 10
|
||||
|
||||
type messageType uint8
|
||||
|
||||
// TLS handshake message types.
|
||||
const (
|
||||
typeClientHello messageType = 1
|
||||
typeServerHello messageType = 2
|
||||
typeNewSessionTicket messageType = 4
|
||||
typeEncryptedExtensions messageType = 8
|
||||
typeCertificate messageType = 11
|
||||
typeCertificateRequest messageType = 13
|
||||
typeCertificateVerify messageType = 15
|
||||
typeFinished messageType = 20
|
||||
)
|
||||
|
||||
func (m messageType) String() string {
|
||||
switch m {
|
||||
case typeClientHello:
|
||||
return "ClientHello"
|
||||
case typeServerHello:
|
||||
return "ServerHello"
|
||||
case typeNewSessionTicket:
|
||||
return "NewSessionTicket"
|
||||
case typeEncryptedExtensions:
|
||||
return "EncryptedExtensions"
|
||||
case typeCertificate:
|
||||
return "Certificate"
|
||||
case typeCertificateRequest:
|
||||
return "CertificateRequest"
|
||||
case typeCertificateVerify:
|
||||
return "CertificateVerify"
|
||||
case typeFinished:
|
||||
return "Finished"
|
||||
default:
|
||||
return fmt.Sprintf("unknown message type: %d", m)
|
||||
}
|
||||
}
|
||||
|
||||
const clientSessionStateRevision = 3
|
||||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
version protocol.VersionNumber
|
||||
}
|
||||
|
||||
var _ ConnWithVersion = &conn{}
|
||||
|
||||
func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion {
|
||||
return &conn{
|
||||
localAddr: local,
|
||||
remoteAddr: remote,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version }
|
||||
|
||||
type cryptoSetup struct {
|
||||
tlsConf *tls.Config
|
||||
extraConf *qtls.ExtraConfig
|
||||
conn *qtls.Conn
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
messageChan chan []byte
|
||||
isReadingHandshakeMessage chan struct{}
|
||||
readFirstHandshakeMessage bool
|
||||
|
||||
ourParams *wire.TransportParameters
|
||||
peerParams *wire.TransportParameters
|
||||
paramsChan <-chan []byte
|
||||
|
||||
runner handshakeRunner
|
||||
|
||||
alertChan chan uint8
|
||||
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
|
||||
handshakeDone chan struct{}
|
||||
// is closed when Close() is called
|
||||
closeChan chan struct{}
|
||||
|
||||
zeroRTTParameters *wire.TransportParameters
|
||||
clientHelloWritten bool
|
||||
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
|
||||
zeroRTTParametersChan chan<- *wire.TransportParameters
|
||||
allow0RTT func() bool
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
tracer logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
mutex sync.Mutex // protects all members below
|
||||
|
||||
handshakeCompleteTime time.Time
|
||||
|
||||
readEncLevel protocol.EncryptionLevel
|
||||
writeEncLevel protocol.EncryptionLevel
|
||||
|
||||
zeroRTTOpener LongHeaderOpener // only set for the server
|
||||
zeroRTTSealer LongHeaderSealer // only set for the client
|
||||
|
||||
initialStream io.Writer
|
||||
initialOpener LongHeaderOpener
|
||||
initialSealer LongHeaderSealer
|
||||
|
||||
handshakeStream io.Writer
|
||||
handshakeOpener LongHeaderOpener
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
aead *updatableAEAD
|
||||
has1RTTSealer bool
|
||||
has1RTTOpener bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ qtls.RecordLayer = &cryptoSetup{}
|
||||
_ CryptoSetup = &cryptoSetup{}
|
||||
)
|
||||
|
||||
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||
func NewCryptoSetupClient(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
localAddr net.Addr,
|
||||
remoteAddr net.Addr,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
enable0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
|
||||
cs, clientHelloWritten := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
tlsConf,
|
||||
enable0RTT,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
protocol.PerspectiveClient,
|
||||
version,
|
||||
)
|
||||
cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
|
||||
return cs, clientHelloWritten
|
||||
}
|
||||
|
||||
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||
func NewCryptoSetupServer(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
localAddr net.Addr,
|
||||
remoteAddr net.Addr,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
allow0RTT func() bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetup {
|
||||
cs, _ := newCryptoSetup(
|
||||
initialStream,
|
||||
handshakeStream,
|
||||
connID,
|
||||
tp,
|
||||
runner,
|
||||
tlsConf,
|
||||
allow0RTT != nil,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
version,
|
||||
)
|
||||
cs.allow0RTT = allow0RTT
|
||||
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
|
||||
return cs
|
||||
}
|
||||
|
||||
func newCryptoSetup(
|
||||
initialStream io.Writer,
|
||||
handshakeStream io.Writer,
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
runner handshakeRunner,
|
||||
tlsConf *tls.Config,
|
||||
enable0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
version protocol.VersionNumber,
|
||||
) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
|
||||
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
|
||||
if tracer != nil {
|
||||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
|
||||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
|
||||
}
|
||||
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version)
|
||||
zeroRTTParametersChan := make(chan *wire.TransportParameters, 1)
|
||||
cs := &cryptoSetup{
|
||||
tlsConf: tlsConf,
|
||||
initialStream: initialStream,
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
handshakeStream: handshakeStream,
|
||||
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
|
||||
readEncLevel: protocol.EncryptionInitial,
|
||||
writeEncLevel: protocol.EncryptionInitial,
|
||||
runner: runner,
|
||||
ourParams: tp,
|
||||
paramsChan: extHandler.TransportParameters(),
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
handshakeDone: make(chan struct{}),
|
||||
alertChan: make(chan uint8),
|
||||
clientHelloWrittenChan: make(chan struct{}),
|
||||
zeroRTTParametersChan: zeroRTTParametersChan,
|
||||
messageChan: make(chan []byte, 100),
|
||||
isReadingHandshakeMessage: make(chan struct{}),
|
||||
closeChan: make(chan struct{}),
|
||||
version: version,
|
||||
}
|
||||
var maxEarlyData uint32
|
||||
if enable0RTT {
|
||||
maxEarlyData = math.MaxUint32
|
||||
}
|
||||
cs.extraConf = &qtls.ExtraConfig{
|
||||
GetExtensions: extHandler.GetExtensions,
|
||||
ReceivedExtensions: extHandler.ReceivedExtensions,
|
||||
AlternativeRecordLayer: cs,
|
||||
EnforceNextProtoSelection: true,
|
||||
MaxEarlyData: maxEarlyData,
|
||||
Accept0RTT: cs.accept0RTT,
|
||||
Rejected0RTT: cs.rejected0RTT,
|
||||
Enable0RTT: enable0RTT,
|
||||
GetAppDataForSessionState: cs.marshalDataForSessionState,
|
||||
SetAppDataFromSessionState: cs.handleDataFromSessionState,
|
||||
}
|
||||
return cs, zeroRTTParametersChan
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
|
||||
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
|
||||
h.initialSealer = initialSealer
|
||||
h.initialOpener = initialOpener
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
|
||||
return h.aead.SetLargestAcked(pn)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) RunHandshake() {
|
||||
// Handle errors that might occur when HandleData() is called.
|
||||
handshakeComplete := make(chan struct{})
|
||||
handshakeErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(h.handshakeDone)
|
||||
if err := h.conn.Handshake(); err != nil {
|
||||
handshakeErrChan <- err
|
||||
return
|
||||
}
|
||||
close(handshakeComplete)
|
||||
}()
|
||||
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
select {
|
||||
case err := <-handshakeErrChan:
|
||||
h.onError(0, err.Error())
|
||||
return
|
||||
case <-h.clientHelloWrittenChan:
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-handshakeComplete: // return when the handshake is done
|
||||
h.mutex.Lock()
|
||||
h.handshakeCompleteTime = time.Now()
|
||||
h.mutex.Unlock()
|
||||
h.runner.OnHandshakeComplete()
|
||||
case <-h.closeChan:
|
||||
// wait until the Handshake() go routine has returned
|
||||
<-h.handshakeDone
|
||||
case alert := <-h.alertChan:
|
||||
handshakeErr := <-handshakeErrChan
|
||||
h.onError(alert, handshakeErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) onError(alert uint8, message string) {
|
||||
var err error
|
||||
if alert == 0 {
|
||||
err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message}
|
||||
} else {
|
||||
err = qerr.NewLocalCryptoError(alert, message)
|
||||
}
|
||||
h.runner.OnError(err)
|
||||
}
|
||||
|
||||
// Close closes the crypto setup.
|
||||
// It aborts the handshake, if it is still running.
|
||||
// It must only be called once.
|
||||
func (h *cryptoSetup) Close() error {
|
||||
close(h.closeChan)
|
||||
// wait until qtls.Handshake() actually returned
|
||||
<-h.handshakeDone
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMessage handles a TLS handshake message.
|
||||
// It is called by the crypto streams when a new message is available.
|
||||
// It returns if it is done with messages on the same encryption level.
|
||||
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
|
||||
msgType := messageType(data[0])
|
||||
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
|
||||
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
|
||||
h.onError(alertUnexpectedMessage, err.Error())
|
||||
return false
|
||||
}
|
||||
h.messageChan <- data
|
||||
if encLevel == protocol.Encryption1RTT {
|
||||
h.handlePostHandshakeMessage()
|
||||
return false
|
||||
}
|
||||
readLoop:
|
||||
for {
|
||||
select {
|
||||
case data := <-h.paramsChan:
|
||||
if data == nil {
|
||||
h.onError(0x6d, "missing quic_transport_parameters extension")
|
||||
} else {
|
||||
h.handleTransportParameters(data)
|
||||
}
|
||||
case <-h.isReadingHandshakeMessage:
|
||||
break readLoop
|
||||
case <-h.handshakeDone:
|
||||
break readLoop
|
||||
case <-h.closeChan:
|
||||
break readLoop
|
||||
}
|
||||
}
|
||||
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
|
||||
// but only if a handshake opener and sealer was created.
|
||||
// Otherwise, a HelloRetryRequest was performed.
|
||||
// We're done with the Handshake encryption level after processing the Finished message.
|
||||
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
|
||||
msgType == typeFinished
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
|
||||
var expected protocol.EncryptionLevel
|
||||
switch msgType {
|
||||
case typeClientHello,
|
||||
typeServerHello:
|
||||
expected = protocol.EncryptionInitial
|
||||
case typeEncryptedExtensions,
|
||||
typeCertificate,
|
||||
typeCertificateRequest,
|
||||
typeCertificateVerify,
|
||||
typeFinished:
|
||||
expected = protocol.EncryptionHandshake
|
||||
case typeNewSessionTicket:
|
||||
expected = protocol.Encryption1RTT
|
||||
default:
|
||||
return fmt.Errorf("unexpected handshake message: %d", msgType)
|
||||
}
|
||||
if encLevel != expected {
|
||||
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleTransportParameters(data []byte) {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||
h.runner.OnError(&qerr.TransportError{
|
||||
ErrorCode: qerr.TransportParameterError,
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
}
|
||||
h.peerParams = &tp
|
||||
h.runner.OnReceivedParams(h.peerParams)
|
||||
}
|
||||
|
||||
// must be called after receiving the transport parameters
|
||||
func (h *cryptoSetup) marshalDataForSessionState() []byte {
|
||||
b := make([]byte, 0, 256)
|
||||
b = quicvarint.Append(b, clientSessionStateRevision)
|
||||
b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
|
||||
return h.peerParams.MarshalForSessionTicket(b)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
|
||||
tp, err := h.handleDataFromSessionStateImpl(data)
|
||||
if err != nil {
|
||||
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
|
||||
return
|
||||
}
|
||||
h.zeroRTTParameters = tp
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
|
||||
r := bytes.NewReader(data)
|
||||
ver, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ver != clientSessionStateRevision {
|
||||
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
|
||||
}
|
||||
rtt, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tp, nil
|
||||
}
|
||||
|
||||
// only valid for the server
|
||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
var appData []byte
|
||||
// Save transport parameters to the session ticket if we're allowing 0-RTT.
|
||||
if h.extraConf.MaxEarlyData > 0 {
|
||||
appData = (&sessionTicket{
|
||||
Parameters: h.ourParams,
|
||||
RTT: h.rttStats.SmoothedRTT(),
|
||||
}).Marshal()
|
||||
}
|
||||
return h.conn.GetSessionTicket(appData)
|
||||
}
|
||||
|
||||
// accept0RTT is called for the server when receiving the client's session ticket.
|
||||
// It decides whether to accept 0-RTT.
|
||||
func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
|
||||
var t sessionTicket
|
||||
if err := t.Unmarshal(sessionTicketData); err != nil {
|
||||
h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
valid := h.ourParams.ValidFor0RTT(t.Parameters)
|
||||
if !valid {
|
||||
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
if !h.allow0RTT() {
|
||||
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
|
||||
h.rttStats.SetInitialRTT(t.RTT)
|
||||
return true
|
||||
}
|
||||
|
||||
// rejected0RTT is called for the client when the server rejects 0-RTT.
|
||||
func (h *cryptoSetup) rejected0RTT() {
|
||||
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
|
||||
|
||||
h.mutex.Lock()
|
||||
had0RTTKeys := h.zeroRTTSealer != nil
|
||||
h.zeroRTTSealer = nil
|
||||
h.mutex.Unlock()
|
||||
|
||||
if had0RTTKeys {
|
||||
h.runner.DropKeys(protocol.Encryption0RTT)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handlePostHandshakeMessage() {
|
||||
// make sure the handshake has already completed
|
||||
<-h.handshakeDone
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
// h.alertChan is an unbuffered channel.
|
||||
// If an error occurs during conn.HandlePostHandshakeMessage,
|
||||
// it will be sent on this channel.
|
||||
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
|
||||
alertChan := make(chan uint8, 1)
|
||||
go func() {
|
||||
<-h.isReadingHandshakeMessage
|
||||
select {
|
||||
case alert := <-h.alertChan:
|
||||
alertChan <- alert
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
if err := h.conn.HandlePostHandshakeMessage(); err != nil {
|
||||
select {
|
||||
case <-h.closeChan:
|
||||
case alert := <-alertChan:
|
||||
h.onError(alert, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadHandshakeMessage is called by TLS.
|
||||
// It blocks until a new handshake message is available.
|
||||
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
|
||||
if !h.readFirstHandshakeMessage {
|
||||
h.readFirstHandshakeMessage = true
|
||||
} else {
|
||||
select {
|
||||
case h.isReadingHandshakeMessage <- struct{}{}:
|
||||
case <-h.closeChan:
|
||||
return nil, errors.New("error while handling the handshake message")
|
||||
}
|
||||
}
|
||||
select {
|
||||
case msg := <-h.messageChan:
|
||||
return msg, nil
|
||||
case <-h.closeChan:
|
||||
return nil, errors.New("error while handling the handshake message")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
h.mutex.Lock()
|
||||
switch encLevel {
|
||||
case qtls.Encryption0RTT:
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
panic("Received 0-RTT read key for the client")
|
||||
}
|
||||
h.zeroRTTOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
h.mutex.Unlock()
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite())
|
||||
}
|
||||
return
|
||||
case qtls.EncryptionHandshake:
|
||||
h.readEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeOpener = newHandshakeOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
h.dropInitialKeys,
|
||||
h.perspective,
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.EncryptionApplication:
|
||||
h.readEncLevel = protocol.Encryption1RTT
|
||||
h.aead.SetReadKey(suite, trafficSecret)
|
||||
h.has1RTTOpener = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
default:
|
||||
panic("unexpected read encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
h.mutex.Lock()
|
||||
switch encLevel {
|
||||
case qtls.Encryption0RTT:
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
panic("Received 0-RTT write key for the server")
|
||||
}
|
||||
h.zeroRTTSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
h.mutex.Unlock()
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
|
||||
}
|
||||
return
|
||||
case qtls.EncryptionHandshake:
|
||||
h.writeEncLevel = protocol.EncryptionHandshake
|
||||
h.handshakeSealer = newHandshakeSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
h.dropInitialKeys,
|
||||
h.perspective,
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.EncryptionApplication:
|
||||
h.writeEncLevel = protocol.Encryption1RTT
|
||||
h.aead.SetWriteKey(suite, trafficSecret)
|
||||
h.has1RTTSealer = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.zeroRTTSealer != nil {
|
||||
h.zeroRTTSealer = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
if h.tracer != nil {
|
||||
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
|
||||
}
|
||||
}
|
||||
default:
|
||||
panic("unexpected write encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
if h.tracer != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteRecord is called when TLS writes data
|
||||
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
//nolint:exhaustive // LS records can only be written for Initial and Handshake.
|
||||
switch h.writeEncLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
// assume that the first WriteRecord call contains the ClientHello
|
||||
n, err := h.initialStream.Write(p)
|
||||
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
|
||||
h.clientHelloWritten = true
|
||||
close(h.clientHelloWrittenChan)
|
||||
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
|
||||
h.logger.Debugf("Doing 0-RTT.")
|
||||
h.zeroRTTParametersChan <- h.zeroRTTParameters
|
||||
} else {
|
||||
h.logger.Debugf("Not doing 0-RTT.")
|
||||
h.zeroRTTParametersChan <- nil
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
case protocol.EncryptionHandshake:
|
||||
return h.handshakeStream.Write(p)
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SendAlert(alert uint8) {
|
||||
select {
|
||||
case h.alertChan <- alert:
|
||||
case <-h.closeChan:
|
||||
// no need to send an alert when we've already closed
|
||||
}
|
||||
}
|
||||
|
||||
// used a callback in the handshakeSealer and handshakeOpener
|
||||
func (h *cryptoSetup) dropInitialKeys() {
|
||||
h.mutex.Lock()
|
||||
h.initialOpener = nil
|
||||
h.initialSealer = nil
|
||||
h.mutex.Unlock()
|
||||
h.runner.DropKeys(protocol.EncryptionInitial)
|
||||
h.logger.Debugf("Dropping Initial keys.")
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetHandshakeConfirmed() {
|
||||
h.aead.SetHandshakeConfirmed()
|
||||
// drop Handshake keys
|
||||
var dropped bool
|
||||
h.mutex.Lock()
|
||||
if h.handshakeOpener != nil {
|
||||
h.handshakeOpener = nil
|
||||
h.handshakeSealer = nil
|
||||
dropped = true
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
if dropped {
|
||||
h.runner.DropKeys(protocol.EncryptionHandshake)
|
||||
h.logger.Debugf("Dropping Handshake keys.")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.zeroRTTSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.handshakeSealer == nil {
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.handshakeSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if !h.has1RTTSealer {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.aead, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.initialOpener == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.zeroRTTOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.handshakeOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.handshakeOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
|
||||
h.zeroRTTOpener = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
if h.tracer != nil {
|
||||
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
|
||||
}
|
||||
}
|
||||
|
||||
if !h.has1RTTOpener {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.aead, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
return qtls.GetConnectionState(h.conn)
|
||||
}
|
136
vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go
generated
vendored
Normal file
136
vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go
generated
vendored
Normal file
|
@ -0,0 +1,136 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
type headerProtector interface {
|
||||
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
}
|
||||
|
||||
func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
|
||||
if v == protocol.Version2 {
|
||||
return "quicv2 hp"
|
||||
}
|
||||
return "quic hp"
|
||||
}
|
||||
|
||||
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
|
||||
hkdfLabel := hkdfHeaderProtectionLabel(v)
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
|
||||
default:
|
||||
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
|
||||
}
|
||||
}
|
||||
|
||||
type aesHeaderProtector struct {
|
||||
mask []byte
|
||||
block cipher.Block
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &aesHeaderProtector{}
|
||||
|
||||
func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
block, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return &aesHeaderProtector{
|
||||
block: block,
|
||||
mask: make([]byte, block.BlockSize()),
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != len(p.mask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
p.block.Encrypt(p.mask, sample)
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
type chachaHeaderProtector struct {
|
||||
mask [5]byte
|
||||
|
||||
key [32]byte
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &chachaHeaderProtector{}
|
||||
|
||||
func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
|
||||
p := &chachaHeaderProtector{
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
copy(p.key[:], hpKey)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != 16 {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
p.mask[i] = 0
|
||||
}
|
||||
cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
|
||||
cipher.XORKeyStream(p.mask[:], p.mask[:])
|
||||
p.applyMask(firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) {
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
29
vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go
generated
vendored
Normal file
29
vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go
generated
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// hkdfExpandLabel HKDF expands a label.
|
||||
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
|
||||
// hkdfExpandLabel in the standard library.
|
||||
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
|
||||
b := make([]byte, 3, 3+6+len(label)+1+len(context))
|
||||
binary.BigEndian.PutUint16(b, uint16(length))
|
||||
b[2] = uint8(6 + len(label))
|
||||
b = append(b, []byte("tls13 ")...)
|
||||
b = append(b, []byte(label)...)
|
||||
b = b[:3+6+len(label)+1]
|
||||
b[3+6+len(label)] = uint8(len(context))
|
||||
b = append(b, context...)
|
||||
|
||||
out := make([]byte, length)
|
||||
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
|
||||
if err != nil || n != length {
|
||||
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
|
||||
}
|
||||
return out
|
||||
}
|
81
vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go
generated
vendored
Normal file
81
vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go
generated
vendored
Normal file
|
@ -0,0 +1,81 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
var (
|
||||
quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
|
||||
quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
|
||||
quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9}
|
||||
)
|
||||
|
||||
const (
|
||||
hkdfLabelKeyV1 = "quic key"
|
||||
hkdfLabelKeyV2 = "quicv2 key"
|
||||
hkdfLabelIVV1 = "quic iv"
|
||||
hkdfLabelIVV2 = "quicv2 iv"
|
||||
)
|
||||
|
||||
func getSalt(v protocol.VersionNumber) []byte {
|
||||
if v == protocol.Version2 {
|
||||
return quicSaltV2
|
||||
}
|
||||
if v == protocol.Version1 {
|
||||
return quicSaltV1
|
||||
}
|
||||
return quicSaltOld
|
||||
}
|
||||
|
||||
var initialSuite = &qtls.CipherSuiteTLS13{
|
||||
ID: tls.TLS_AES_128_GCM_SHA256,
|
||||
KeyLen: 16,
|
||||
AEAD: qtls.AEADAESGCMTLS13,
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
|
||||
clientSecret, serverSecret := computeSecrets(connID, v)
|
||||
var mySecret, otherSecret []byte
|
||||
if pers == protocol.PerspectiveClient {
|
||||
mySecret = clientSecret
|
||||
otherSecret = serverSecret
|
||||
} else {
|
||||
mySecret = serverSecret
|
||||
otherSecret = clientSecret
|
||||
}
|
||||
myKey, myIV := computeInitialKeyAndIV(mySecret, v)
|
||||
otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v)
|
||||
|
||||
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
|
||||
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
|
||||
|
||||
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)),
|
||||
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
|
||||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
|
||||
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
|
||||
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
|
||||
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
keyLabel = hkdfLabelKeyV2
|
||||
ivLabel = hkdfLabelIVV2
|
||||
}
|
||||
key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
|
||||
iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
|
||||
return
|
||||
}
|
102
vendor/github.com/quic-go/quic-go/internal/handshake/interface.go
generated
vendored
Normal file
102
vendor/github.com/quic-go/quic-go/internal/handshake/interface.go
generated
vendored
Normal file
|
@ -0,0 +1,102 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level,
|
||||
// but the corresponding opener has not yet been initialized
|
||||
// This can happen when packets arrive out of order.
|
||||
ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available")
|
||||
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
|
||||
// but the corresponding keys have already been dropped.
|
||||
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
|
||||
// ErrDecryptionFailed is returned when the AEAD fails to open the packet.
|
||||
ErrDecryptionFailed = errors.New("decryption failed")
|
||||
)
|
||||
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
type ConnectionState = qtls.ConnectionState
|
||||
|
||||
type headerDecryptor interface {
|
||||
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
}
|
||||
|
||||
// LongHeaderOpener opens a long header packet
|
||||
type LongHeaderOpener interface {
|
||||
headerDecryptor
|
||||
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
|
||||
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ShortHeaderOpener opens a short header packet
|
||||
type ShortHeaderOpener interface {
|
||||
headerDecryptor
|
||||
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
|
||||
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// LongHeaderSealer seals a long header packet
|
||||
type LongHeaderSealer interface {
|
||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
Overhead() int
|
||||
}
|
||||
|
||||
// ShortHeaderSealer seals a short header packet
|
||||
type ShortHeaderSealer interface {
|
||||
LongHeaderSealer
|
||||
KeyPhase() protocol.KeyPhaseBit
|
||||
}
|
||||
|
||||
// A tlsExtensionHandler sends and received the QUIC TLS extension.
|
||||
type tlsExtensionHandler interface {
|
||||
GetExtensions(msgType uint8) []qtls.Extension
|
||||
ReceivedExtensions(msgType uint8, exts []qtls.Extension)
|
||||
TransportParameters() <-chan []byte
|
||||
}
|
||||
|
||||
type handshakeRunner interface {
|
||||
OnReceivedParams(*wire.TransportParameters)
|
||||
OnHandshakeComplete()
|
||||
OnError(error)
|
||||
DropKeys(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||
type CryptoSetup interface {
|
||||
RunHandshake()
|
||||
io.Closer
|
||||
ChangeConnectionID(protocol.ConnectionID)
|
||||
GetSessionTicket() ([]byte, error)
|
||||
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||
SetHandshakeConfirmed()
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetInitialOpener() (LongHeaderOpener, error)
|
||||
GetHandshakeOpener() (LongHeaderOpener, error)
|
||||
Get0RTTOpener() (LongHeaderOpener, error)
|
||||
Get1RTTOpener() (ShortHeaderOpener, error)
|
||||
|
||||
GetInitialSealer() (LongHeaderSealer, error)
|
||||
GetHandshakeSealer() (LongHeaderSealer, error)
|
||||
Get0RTTSealer() (LongHeaderSealer, error)
|
||||
Get1RTTSealer() (ShortHeaderSealer, error)
|
||||
}
|
||||
|
||||
// ConnWithVersion is the connection used in the ClientHelloInfo.
|
||||
// It can be used to determine the QUIC version in use.
|
||||
type ConnWithVersion interface {
|
||||
net.Conn
|
||||
GetQUICVersion() protocol.VersionNumber
|
||||
}
|
3
vendor/github.com/quic-go/quic-go/internal/handshake/mockgen.go
generated
vendored
Normal file
3
vendor/github.com/quic-go/quic-go/internal/handshake/mockgen.go
generated
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
package handshake
|
||||
|
||||
//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/quic-go/quic-go/internal/handshake handshakeRunner"
|
70
vendor/github.com/quic-go/quic-go/internal/handshake/retry.go
generated
vendored
Normal file
70
vendor/github.com/quic-go/quic-go/internal/handshake/retry.go
generated
vendored
Normal file
|
@ -0,0 +1,70 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
retryAEADdraft29 cipher.AEAD // used for QUIC draft versions up to 34
|
||||
retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
|
||||
retryAEADv2 cipher.AEAD // used for QUIC v2
|
||||
)
|
||||
|
||||
func init() {
|
||||
retryAEADdraft29 = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1})
|
||||
retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
|
||||
retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
|
||||
}
|
||||
|
||||
func initAEAD(key [16]byte) cipher.AEAD {
|
||||
aes, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return aead
|
||||
}
|
||||
|
||||
var (
|
||||
retryBuf bytes.Buffer
|
||||
retryMutex sync.Mutex
|
||||
retryNonceDraft29 = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c}
|
||||
retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
|
||||
retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a}
|
||||
)
|
||||
|
||||
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
|
||||
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
|
||||
retryMutex.Lock()
|
||||
defer retryMutex.Unlock()
|
||||
|
||||
retryBuf.WriteByte(uint8(origDestConnID.Len()))
|
||||
retryBuf.Write(origDestConnID.Bytes())
|
||||
retryBuf.Write(retry)
|
||||
defer retryBuf.Reset()
|
||||
|
||||
var tag [16]byte
|
||||
var sealed []byte
|
||||
//nolint:exhaustive // These are all the versions we support
|
||||
switch version {
|
||||
case protocol.Version1:
|
||||
sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
|
||||
case protocol.Version2:
|
||||
sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes())
|
||||
default:
|
||||
sealed = retryAEADdraft29.Seal(tag[:0], retryNonceDraft29[:], nil, retryBuf.Bytes())
|
||||
}
|
||||
if len(sealed) != 16 {
|
||||
panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed)))
|
||||
}
|
||||
return &tag
|
||||
}
|
47
vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go
generated
vendored
Normal file
47
vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go
generated
vendored
Normal file
|
@ -0,0 +1,47 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
const sessionTicketRevision = 2
|
||||
|
||||
type sessionTicket struct {
|
||||
Parameters *wire.TransportParameters
|
||||
RTT time.Duration // to be encoded in mus
|
||||
}
|
||||
|
||||
func (t *sessionTicket) Marshal() []byte {
|
||||
b := make([]byte, 0, 256)
|
||||
b = quicvarint.Append(b, sessionTicketRevision)
|
||||
b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
|
||||
return t.Parameters.MarshalForSessionTicket(b)
|
||||
}
|
||||
|
||||
func (t *sessionTicket) Unmarshal(b []byte) error {
|
||||
r := bytes.NewReader(b)
|
||||
rev, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return errors.New("failed to read session ticket revision")
|
||||
}
|
||||
if rev != sessionTicketRevision {
|
||||
return fmt.Errorf("unknown session ticket revision: %d", rev)
|
||||
}
|
||||
rtt, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return errors.New("failed to read RTT")
|
||||
}
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
|
||||
}
|
||||
t.Parameters = &tp
|
||||
t.RTT = time.Duration(rtt) * time.Microsecond
|
||||
return nil
|
||||
}
|
68
vendor/github.com/quic-go/quic-go/internal/handshake/tls_extension_handler.go
generated
vendored
Normal file
68
vendor/github.com/quic-go/quic-go/internal/handshake/tls_extension_handler.go
generated
vendored
Normal file
|
@ -0,0 +1,68 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
const (
|
||||
quicTLSExtensionTypeOldDrafts = 0xffa5
|
||||
quicTLSExtensionType = 0x39
|
||||
)
|
||||
|
||||
type extensionHandler struct {
|
||||
ourParams []byte
|
||||
paramsChan chan []byte
|
||||
|
||||
extensionType uint16
|
||||
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ tlsExtensionHandler = &extensionHandler{}
|
||||
|
||||
// newExtensionHandler creates a new extension handler
|
||||
func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler {
|
||||
et := uint16(quicTLSExtensionType)
|
||||
if v != protocol.Version1 {
|
||||
et = quicTLSExtensionTypeOldDrafts
|
||||
}
|
||||
return &extensionHandler{
|
||||
ourParams: params,
|
||||
paramsChan: make(chan []byte),
|
||||
perspective: pers,
|
||||
extensionType: et,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
|
||||
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) ||
|
||||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) {
|
||||
return nil
|
||||
}
|
||||
return []qtls.Extension{{
|
||||
Type: h.extensionType,
|
||||
Data: h.ourParams,
|
||||
}}
|
||||
}
|
||||
|
||||
func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
|
||||
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) ||
|
||||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) {
|
||||
return
|
||||
}
|
||||
|
||||
var data []byte
|
||||
for _, ext := range exts {
|
||||
if ext.Type == h.extensionType {
|
||||
data = ext.Data
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
h.paramsChan <- data
|
||||
}
|
||||
|
||||
func (h *extensionHandler) TransportParameters() <-chan []byte {
|
||||
return h.paramsChan
|
||||
}
|
127
vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go
generated
vendored
Normal file
127
vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go
generated
vendored
Normal file
|
@ -0,0 +1,127 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenPrefixIP byte = iota
|
||||
tokenPrefixString
|
||||
)
|
||||
|
||||
// A Token is derived from the client address and can be used to verify the ownership of this address.
|
||||
type Token struct {
|
||||
IsRetryToken bool
|
||||
SentTime time.Time
|
||||
encodedRemoteAddr []byte
|
||||
// only set for retry tokens
|
||||
OriginalDestConnectionID protocol.ConnectionID
|
||||
RetrySrcConnectionID protocol.ConnectionID
|
||||
}
|
||||
|
||||
// ValidateRemoteAddr validates the address, but does not check expiration
|
||||
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
|
||||
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
IsRetryToken bool
|
||||
RemoteAddr []byte
|
||||
Timestamp int64
|
||||
OriginalDestConnectionID []byte
|
||||
RetrySrcConnectionID []byte
|
||||
}
|
||||
|
||||
// A TokenGenerator generates tokens
|
||||
type TokenGenerator struct {
|
||||
tokenProtector tokenProtector
|
||||
}
|
||||
|
||||
// NewTokenGenerator initializes a new TookenGenerator
|
||||
func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) {
|
||||
tokenProtector, err := newTokenProtector(rand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenGenerator{
|
||||
tokenProtector: tokenProtector,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewRetryToken generates a new token for a Retry for a given source address
|
||||
func (g *TokenGenerator) NewRetryToken(
|
||||
raddr net.Addr,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
retrySrcConnID protocol.ConnectionID,
|
||||
) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: encodeRemoteAddr(raddr),
|
||||
OriginalDestConnectionID: origDestConnID.Bytes(),
|
||||
RetrySrcConnectionID: retrySrcConnID.Bytes(),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.tokenProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// NewToken generates a new token to be sent in a NEW_TOKEN frame
|
||||
func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
RemoteAddr: encodeRemoteAddr(raddr),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.tokenProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token
|
||||
func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
|
||||
// if the client didn't send any token, DecodeToken will be called with a nil-slice
|
||||
if len(encrypted) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := g.tokenProtector.DecodeToken(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := &token{}
|
||||
rest, err := asn1.Unmarshal(data, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||
}
|
||||
token := &Token{
|
||||
IsRetryToken: t.IsRetryToken,
|
||||
SentTime: time.Unix(0, t.Timestamp),
|
||||
encodedRemoteAddr: t.RemoteAddr,
|
||||
}
|
||||
if t.IsRetryToken {
|
||||
token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID)
|
||||
token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the token
|
||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||
return append([]byte{tokenPrefixIP}, udpAddr.IP...)
|
||||
}
|
||||
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
|
||||
}
|
89
vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go
generated
vendored
Normal file
89
vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go
generated
vendored
Normal file
|
@ -0,0 +1,89 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// TokenProtector is used to create and verify a token
|
||||
type tokenProtector interface {
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
const (
|
||||
tokenSecretSize = 32
|
||||
tokenNonceSize = 32
|
||||
)
|
||||
|
||||
// tokenProtector is used to create and verify a token
|
||||
type tokenProtectorImpl struct {
|
||||
rand io.Reader
|
||||
secret []byte
|
||||
}
|
||||
|
||||
// newTokenProtector creates a source for source address tokens
|
||||
func newTokenProtector(rand io.Reader) (tokenProtector, error) {
|
||||
secret := make([]byte, tokenSecretSize)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tokenProtectorImpl{
|
||||
rand: rand,
|
||||
secret: secret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewToken encodes data into a new token.
|
||||
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
|
||||
nonce := make([]byte, tokenNonceSize)
|
||||
if _, err := s.rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token.
|
||||
func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < tokenNonceSize {
|
||||
return nil, fmt.Errorf("token too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:tokenNonceSize]
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
|
||||
}
|
||||
|
||||
func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
|
||||
h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source"))
|
||||
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
|
||||
if _, err := io.ReadFull(h, key); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aeadNonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(h, aeadNonce); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return aead, aeadNonce, nil
|
||||
}
|
323
vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go
generated
vendored
Normal file
323
vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go
generated
vendored
Normal file
|
@ -0,0 +1,323 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
|
||||
// It's a package-level variable to allow modifying it for testing purposes.
|
||||
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
||||
|
||||
type updatableAEAD struct {
|
||||
suite *qtls.CipherSuiteTLS13
|
||||
|
||||
keyPhase protocol.KeyPhase
|
||||
largestAcked protocol.PacketNumber
|
||||
firstPacketNumber protocol.PacketNumber
|
||||
handshakeConfirmed bool
|
||||
|
||||
keyUpdateInterval uint64
|
||||
invalidPacketLimit uint64
|
||||
invalidPacketCount uint64
|
||||
|
||||
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
|
||||
prevRcvAEADExpiry time.Time
|
||||
prevRcvAEAD cipher.AEAD
|
||||
|
||||
firstRcvdWithCurrentKey protocol.PacketNumber
|
||||
firstSentWithCurrentKey protocol.PacketNumber
|
||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||
numRcvdWithCurrentKey uint64
|
||||
numSentWithCurrentKey uint64
|
||||
rcvAEAD cipher.AEAD
|
||||
sendAEAD cipher.AEAD
|
||||
// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
|
||||
aeadOverhead int
|
||||
|
||||
nextRcvAEAD cipher.AEAD
|
||||
nextSendAEAD cipher.AEAD
|
||||
nextRcvTrafficSecret []byte
|
||||
nextSendTrafficSecret []byte
|
||||
|
||||
headerDecrypter headerProtector
|
||||
headerEncrypter headerProtector
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
tracer logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
version protocol.VersionNumber
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var (
|
||||
_ ShortHeaderOpener = &updatableAEAD{}
|
||||
_ ShortHeaderSealer = &updatableAEAD{}
|
||||
)
|
||||
|
||||
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD {
|
||||
return &updatableAEAD{
|
||||
firstPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestAcked: protocol.InvalidPacketNumber,
|
||||
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
keyUpdateInterval: KeyUpdateInterval,
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) rollKeys() {
|
||||
if a.prevRcvAEAD != nil {
|
||||
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
|
||||
if a.tracer != nil {
|
||||
a.tracer.DroppedKey(a.keyPhase - 1)
|
||||
}
|
||||
a.prevRcvAEADExpiry = time.Time{}
|
||||
}
|
||||
|
||||
a.keyPhase++
|
||||
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.numRcvdWithCurrentKey = 0
|
||||
a.numSentWithCurrentKey = 0
|
||||
a.prevRcvAEAD = a.rcvAEAD
|
||||
a.rcvAEAD = a.nextRcvAEAD
|
||||
a.sendAEAD = a.nextSendAEAD
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version)
|
||||
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
|
||||
d := 3 * a.rttStats.PTO(true)
|
||||
a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
|
||||
a.prevRcvAEADExpiry = now.Add(d)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
|
||||
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
|
||||
}
|
||||
|
||||
// For the client, this function is called before SetWriteKey.
|
||||
// For the server, this function is called after SetWriteKey.
|
||||
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite == nil {
|
||||
a.setAEADParameters(a.rcvAEAD, suite)
|
||||
}
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
// For the client, this function is called after SetReadKey.
|
||||
// For the server, this function is called before SetWriteKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
|
||||
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite == nil {
|
||||
a.setAEADParameters(a.sendAEAD, suite)
|
||||
}
|
||||
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
|
||||
a.nonceBuf = make([]byte, aead.NonceSize())
|
||||
a.aeadOverhead = aead.Overhead()
|
||||
a.suite = suite
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
a.invalidPacketLimit = protocol.InvalidPacketLimitAES
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
|
||||
}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
|
||||
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
|
||||
if err == ErrDecryptionFailed {
|
||||
a.invalidPacketCount++
|
||||
if a.invalidPacketCount >= a.invalidPacketLimit {
|
||||
return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
a.highestRcvdPN = utils.Max(a.highestRcvdPN, pn)
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||
if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
|
||||
a.prevRcvAEAD = nil
|
||||
a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
|
||||
a.prevRcvAEADExpiry = time.Time{}
|
||||
if a.tracer != nil {
|
||||
a.tracer.DroppedKey(a.keyPhase - 1)
|
||||
}
|
||||
}
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
if kp != a.keyPhase.Bit() {
|
||||
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
||||
if a.prevRcvAEAD == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
// we updated the key, but the peer hasn't updated yet
|
||||
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
err = ErrDecryptionFailed
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
// try opening the packet with the next key phase
|
||||
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
return nil, ErrDecryptionFailed
|
||||
}
|
||||
// Opening succeeded. Check if the peer was allowed to update.
|
||||
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.KeyUpdateError,
|
||||
ErrorMessage: "keys updated too quickly",
|
||||
}
|
||||
}
|
||||
a.rollKeys()
|
||||
a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
|
||||
// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
|
||||
// Start a timer to drop the previous key generation.
|
||||
a.startKeyDropTimer(rcvTime)
|
||||
if a.tracer != nil {
|
||||
a.tracer.UpdatedKey(a.keyPhase, true)
|
||||
}
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
return dec, err
|
||||
}
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
return dec, ErrDecryptionFailed
|
||||
}
|
||||
a.numRcvdWithCurrentKey++
|
||||
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
// We initiated the key updated, and now we received the first packet protected with the new key phase.
|
||||
// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
|
||||
if a.keyPhase > 0 {
|
||||
a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
|
||||
a.startKeyDropTimer(rcvTime)
|
||||
}
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
a.firstSentWithCurrentKey = pn
|
||||
}
|
||||
if a.firstPacketNumber == protocol.InvalidPacketNumber {
|
||||
a.firstPacketNumber = pn
|
||||
}
|
||||
a.numSentWithCurrentKey++
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
|
||||
if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||
pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.KeyUpdateError,
|
||||
ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
|
||||
}
|
||||
}
|
||||
a.largestAcked = pn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetHandshakeConfirmed() {
|
||||
a.handshakeConfirmed = true
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) updateAllowed() bool {
|
||||
if !a.handshakeConfirmed {
|
||||
return false
|
||||
}
|
||||
// the first key update is allowed as soon as the handshake is confirmed
|
||||
return a.keyPhase == 0 ||
|
||||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
|
||||
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||
a.largestAcked != protocol.InvalidPacketNumber &&
|
||||
a.largestAcked >= a.firstSentWithCurrentKey)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
||||
if !a.updateAllowed() {
|
||||
return false
|
||||
}
|
||||
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
|
||||
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
||||
return true
|
||||
}
|
||||
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
|
||||
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
|
||||
if a.shouldInitiateKeyUpdate() {
|
||||
a.rollKeys()
|
||||
a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
|
||||
if a.tracer != nil {
|
||||
a.tracer.UpdatedKey(a.keyPhase, false)
|
||||
}
|
||||
}
|
||||
return a.keyPhase.Bit()
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Overhead() int {
|
||||
return a.aeadOverhead
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
|
||||
return a.firstPacketNumber
|
||||
}
|
50
vendor/github.com/quic-go/quic-go/internal/logutils/frame.go
generated
vendored
Normal file
50
vendor/github.com/quic-go/quic-go/internal/logutils/frame.go
generated
vendored
Normal file
|
@ -0,0 +1,50 @@
|
|||
package logutils
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// ConvertFrame converts a wire.Frame into a logging.Frame.
|
||||
// This makes it possible for external packages to access the frames.
|
||||
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
|
||||
func ConvertFrame(frame wire.Frame) logging.Frame {
|
||||
switch f := frame.(type) {
|
||||
case *wire.AckFrame:
|
||||
// We use a pool for ACK frames.
|
||||
// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
|
||||
return ConvertAckFrame(f)
|
||||
case *wire.CryptoFrame:
|
||||
return &logging.CryptoFrame{
|
||||
Offset: f.Offset,
|
||||
Length: protocol.ByteCount(len(f.Data)),
|
||||
}
|
||||
case *wire.StreamFrame:
|
||||
return &logging.StreamFrame{
|
||||
StreamID: f.StreamID,
|
||||
Offset: f.Offset,
|
||||
Length: f.DataLen(),
|
||||
Fin: f.Fin,
|
||||
}
|
||||
case *wire.DatagramFrame:
|
||||
return &logging.DatagramFrame{
|
||||
Length: logging.ByteCount(len(f.Data)),
|
||||
}
|
||||
default:
|
||||
return logging.Frame(frame)
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertAckFrame(f *wire.AckFrame) *logging.AckFrame {
|
||||
ranges := make([]wire.AckRange, 0, len(f.AckRanges))
|
||||
ranges = append(ranges, f.AckRanges...)
|
||||
ack := &logging.AckFrame{
|
||||
AckRanges: ranges,
|
||||
DelayTime: f.DelayTime,
|
||||
ECNCE: f.ECNCE,
|
||||
ECT0: f.ECT0,
|
||||
ECT1: f.ECT1,
|
||||
}
|
||||
return ack
|
||||
}
|
116
vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go
generated
vendored
Normal file
116
vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go
generated
vendored
Normal file
|
@ -0,0 +1,116 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length")
|
||||
|
||||
// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999.
|
||||
// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1
|
||||
// restricts the length to 20 bytes.
|
||||
type ArbitraryLenConnectionID []byte
|
||||
|
||||
func (c ArbitraryLenConnectionID) Len() int {
|
||||
return len(c)
|
||||
}
|
||||
|
||||
func (c ArbitraryLenConnectionID) Bytes() []byte {
|
||||
return c
|
||||
}
|
||||
|
||||
func (c ArbitraryLenConnectionID) String() string {
|
||||
if c.Len() == 0 {
|
||||
return "(empty)"
|
||||
}
|
||||
return fmt.Sprintf("%x", c.Bytes())
|
||||
}
|
||||
|
||||
const maxConnectionIDLen = 20
|
||||
|
||||
// A ConnectionID in QUIC
|
||||
type ConnectionID struct {
|
||||
b [20]byte
|
||||
l uint8
|
||||
}
|
||||
|
||||
// GenerateConnectionID generates a connection ID using cryptographic random
|
||||
func GenerateConnectionID(l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
c.l = uint8(l)
|
||||
_, err := rand.Read(c.b[:l])
|
||||
return c, err
|
||||
}
|
||||
|
||||
// ParseConnectionID interprets b as a Connection ID.
|
||||
// It panics if b is longer than 20 bytes.
|
||||
func ParseConnectionID(b []byte) ConnectionID {
|
||||
if len(b) > maxConnectionIDLen {
|
||||
panic("invalid conn id length")
|
||||
}
|
||||
var c ConnectionID
|
||||
c.l = uint8(len(b))
|
||||
copy(c.b[:c.l], b)
|
||||
return c
|
||||
}
|
||||
|
||||
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
|
||||
// It uses a length randomly chosen between 8 and 20 bytes.
|
||||
func GenerateConnectionIDForInitial() (ConnectionID, error) {
|
||||
r := make([]byte, 1)
|
||||
if _, err := rand.Read(r); err != nil {
|
||||
return ConnectionID{}, err
|
||||
}
|
||||
l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
|
||||
return GenerateConnectionID(l)
|
||||
}
|
||||
|
||||
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
|
||||
// It returns io.EOF if there are not enough bytes to read.
|
||||
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
if l == 0 {
|
||||
return c, nil
|
||||
}
|
||||
if l > maxConnectionIDLen {
|
||||
return c, ErrInvalidConnectionIDLen
|
||||
}
|
||||
c.l = uint8(l)
|
||||
_, err := io.ReadFull(r, c.b[:l])
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return c, io.EOF
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// Len returns the length of the connection ID in bytes
|
||||
func (c ConnectionID) Len() int {
|
||||
return int(c.l)
|
||||
}
|
||||
|
||||
// Bytes returns the byte representation
|
||||
func (c ConnectionID) Bytes() []byte {
|
||||
return c.b[:c.l]
|
||||
}
|
||||
|
||||
func (c ConnectionID) String() string {
|
||||
if c.Len() == 0 {
|
||||
return "(empty)"
|
||||
}
|
||||
return fmt.Sprintf("%x", c.Bytes())
|
||||
}
|
||||
|
||||
type DefaultConnectionIDGenerator struct {
|
||||
ConnLen int
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
|
||||
return GenerateConnectionID(d.ConnLen)
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
|
||||
return d.ConnLen
|
||||
}
|
30
vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go
generated
vendored
Normal file
30
vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go
generated
vendored
Normal file
|
@ -0,0 +1,30 @@
|
|||
package protocol
|
||||
|
||||
// EncryptionLevel is the encryption level
|
||||
// Default value is Unencrypted
|
||||
type EncryptionLevel uint8
|
||||
|
||||
const (
|
||||
// EncryptionInitial is the Initial encryption level
|
||||
EncryptionInitial EncryptionLevel = 1 + iota
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT
|
||||
// Encryption1RTT is the 1-RTT encryption level
|
||||
Encryption1RTT
|
||||
)
|
||||
|
||||
func (e EncryptionLevel) String() string {
|
||||
switch e {
|
||||
case EncryptionInitial:
|
||||
return "Initial"
|
||||
case EncryptionHandshake:
|
||||
return "Handshake"
|
||||
case Encryption0RTT:
|
||||
return "0-RTT"
|
||||
case Encryption1RTT:
|
||||
return "1-RTT"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
36
vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go
generated
vendored
Normal file
36
vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go
generated
vendored
Normal file
|
@ -0,0 +1,36 @@
|
|||
package protocol
|
||||
|
||||
// KeyPhase is the key phase
|
||||
type KeyPhase uint64
|
||||
|
||||
// Bit determines the key phase bit
|
||||
func (p KeyPhase) Bit() KeyPhaseBit {
|
||||
if p%2 == 0 {
|
||||
return KeyPhaseZero
|
||||
}
|
||||
return KeyPhaseOne
|
||||
}
|
||||
|
||||
// KeyPhaseBit is the key phase bit
|
||||
type KeyPhaseBit uint8
|
||||
|
||||
const (
|
||||
// KeyPhaseUndefined is an undefined key phase
|
||||
KeyPhaseUndefined KeyPhaseBit = iota
|
||||
// KeyPhaseZero is key phase 0
|
||||
KeyPhaseZero
|
||||
// KeyPhaseOne is key phase 1
|
||||
KeyPhaseOne
|
||||
)
|
||||
|
||||
func (p KeyPhaseBit) String() string {
|
||||
//nolint:exhaustive
|
||||
switch p {
|
||||
case KeyPhaseZero:
|
||||
return "0"
|
||||
case KeyPhaseOne:
|
||||
return "1"
|
||||
default:
|
||||
return "undefined"
|
||||
}
|
||||
}
|
79
vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go
generated
vendored
Normal file
79
vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go
generated
vendored
Normal file
|
@ -0,0 +1,79 @@
|
|||
package protocol
|
||||
|
||||
// A PacketNumber in QUIC
|
||||
type PacketNumber int64
|
||||
|
||||
// InvalidPacketNumber is a packet number that is never sent.
|
||||
// In QUIC, 0 is a valid packet number.
|
||||
const InvalidPacketNumber PacketNumber = -1
|
||||
|
||||
// PacketNumberLen is the length of the packet number in bytes
|
||||
type PacketNumberLen uint8
|
||||
|
||||
const (
|
||||
// PacketNumberLen1 is a packet number length of 1 byte
|
||||
PacketNumberLen1 PacketNumberLen = 1
|
||||
// PacketNumberLen2 is a packet number length of 2 bytes
|
||||
PacketNumberLen2 PacketNumberLen = 2
|
||||
// PacketNumberLen3 is a packet number length of 3 bytes
|
||||
PacketNumberLen3 PacketNumberLen = 3
|
||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||
PacketNumberLen4 PacketNumberLen = 4
|
||||
)
|
||||
|
||||
// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
||||
func DecodePacketNumber(
|
||||
packetNumberLength PacketNumberLen,
|
||||
lastPacketNumber PacketNumber,
|
||||
wirePacketNumber PacketNumber,
|
||||
) PacketNumber {
|
||||
var epochDelta PacketNumber
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 8
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 16
|
||||
case PacketNumberLen3:
|
||||
epochDelta = PacketNumber(1) << 24
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 32
|
||||
}
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
var prevEpochBegin PacketNumber
|
||||
if epoch > epochDelta {
|
||||
prevEpochBegin = epoch - epochDelta
|
||||
}
|
||||
nextEpochBegin := epoch + epochDelta
|
||||
return closestTo(
|
||||
lastPacketNumber+1,
|
||||
epoch+wirePacketNumber,
|
||||
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
|
||||
)
|
||||
}
|
||||
|
||||
func closestTo(target, a, b PacketNumber) PacketNumber {
|
||||
if delta(target, a) < delta(target, b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func delta(a, b PacketNumber) PacketNumber {
|
||||
if a < b {
|
||||
return b - a
|
||||
}
|
||||
return a - b
|
||||
}
|
||||
|
||||
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - leastUnacked)
|
||||
if diff < (1 << (16 - 1)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if diff < (1 << (24 - 1)) {
|
||||
return PacketNumberLen3
|
||||
}
|
||||
return PacketNumberLen4
|
||||
}
|
193
vendor/github.com/quic-go/quic-go/internal/protocol/params.go
generated
vendored
Normal file
193
vendor/github.com/quic-go/quic-go/internal/protocol/params.go
generated
vendored
Normal file
|
@ -0,0 +1,193 @@
|
|||
package protocol
|
||||
|
||||
import "time"
|
||||
|
||||
// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use.
|
||||
const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB
|
||||
|
||||
// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
|
||||
const InitialPacketSizeIPv4 = 1252
|
||||
|
||||
// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
|
||||
const InitialPacketSizeIPv6 = 1232
|
||||
|
||||
// MaxCongestionWindowPackets is the maximum congestion window in packet.
|
||||
const MaxCongestionWindowPackets = 10000
|
||||
|
||||
// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection.
|
||||
const MaxUndecryptablePackets = 32
|
||||
|
||||
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
|
||||
// This is the value that Chromium is using
|
||||
const ConnectionFlowControlMultiplier = 1.5
|
||||
|
||||
// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data
|
||||
const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb
|
||||
|
||||
// DefaultInitialMaxData is the connection-level flow control window for receiving data
|
||||
const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData
|
||||
|
||||
// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data
|
||||
const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB
|
||||
|
||||
// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data
|
||||
const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB
|
||||
|
||||
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
|
||||
const WindowUpdateThreshold = 0.25
|
||||
|
||||
// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open
|
||||
const DefaultMaxIncomingStreams = 100
|
||||
|
||||
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
|
||||
const DefaultMaxIncomingUniStreams = 100
|
||||
|
||||
// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed.
|
||||
const MaxServerUnprocessedPackets = 1024
|
||||
|
||||
// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed.
|
||||
const MaxConnUnprocessedPackets = 256
|
||||
|
||||
// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack.
|
||||
// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod.
|
||||
const SkipPacketInitialPeriod PacketNumber = 256
|
||||
|
||||
// SkipPacketMaxPeriod is the maximum period length used for packet number skipping.
|
||||
const SkipPacketMaxPeriod PacketNumber = 128 * 1024
|
||||
|
||||
// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting.
|
||||
// If the queue is full, new connection attempts will be rejected.
|
||||
const MaxAcceptQueueSize = 32
|
||||
|
||||
// TokenValidity is the duration that a (non-retry) token is considered valid
|
||||
const TokenValidity = 24 * time.Hour
|
||||
|
||||
// RetryTokenValidity is the duration that a retry token is considered valid
|
||||
const RetryTokenValidity = 10 * time.Second
|
||||
|
||||
// MaxOutstandingSentPackets is maximum number of packets saved for retransmission.
|
||||
// When reached, it imposes a soft limit on sending new packets:
|
||||
// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent.
|
||||
const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets
|
||||
|
||||
// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission.
|
||||
// When reached, no more packets will be sent.
|
||||
// This value *must* be larger than MaxOutstandingSentPackets.
|
||||
const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4
|
||||
|
||||
// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK,
|
||||
// but no ack-eliciting frames, that we send in a row
|
||||
const MaxNonAckElicitingAcks = 19
|
||||
|
||||
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
|
||||
// prevents DoS attacks against the streamFrameSorter
|
||||
const MaxStreamFrameSorterGaps = 1000
|
||||
|
||||
// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame
|
||||
// that we use the buffer for. This protects against a DoS where an attacker would send us
|
||||
// very small STREAM frames to consume a lot of memory.
|
||||
const MinStreamFrameBufferSize = 128
|
||||
|
||||
// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack.
|
||||
// If a packet has less than this number of bytes, we won't coalesce any more packets onto it.
|
||||
const MinCoalescedPacketSize = 128
|
||||
|
||||
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
|
||||
// This limits the size of the ClientHello and Certificates that can be received.
|
||||
const MaxCryptoStreamOffset = 16 * (1 << 10)
|
||||
|
||||
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
|
||||
const MinRemoteIdleTimeout = 5 * time.Second
|
||||
|
||||
// DefaultIdleTimeout is the default idle timeout
|
||||
const DefaultIdleTimeout = 30 * time.Second
|
||||
|
||||
// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion.
|
||||
const DefaultHandshakeIdleTimeout = 5 * time.Second
|
||||
|
||||
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
|
||||
const DefaultHandshakeTimeout = 10 * time.Second
|
||||
|
||||
// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive.
|
||||
// It should be shorter than the time that NATs clear their mapping.
|
||||
const MaxKeepAliveInterval = 20 * time.Second
|
||||
|
||||
// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE.
|
||||
// after this time all information about the old connection will be deleted
|
||||
const RetiredConnectionIDDeleteTimeout = 5 * time.Second
|
||||
|
||||
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
|
||||
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
|
||||
// 1. it reduces the framing overhead
|
||||
// 2. it reduces the head-of-line blocking, when a packet is lost
|
||||
const MinStreamFrameSize ByteCount = 128
|
||||
|
||||
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
|
||||
// we send after the handshake completes.
|
||||
const MaxPostHandshakeCryptoFrameSize = 1000
|
||||
|
||||
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
||||
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
||||
// The MaxAckFrameSize should be large enough to encode many ACK range,
|
||||
// but must ensure that a maximum size ACK frame fits into one packet.
|
||||
const MaxAckFrameSize ByteCount = 1000
|
||||
|
||||
// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221).
|
||||
// The size is chosen such that a DATAGRAM frame fits into a QUIC packet.
|
||||
const MaxDatagramFrameSize ByteCount = 1200
|
||||
|
||||
// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221)
|
||||
const DatagramRcvQueueLen = 128
|
||||
|
||||
// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame.
|
||||
// It also serves as a limit for the packet history.
|
||||
// If at any point we keep track of more ranges, old ranges are discarded.
|
||||
const MaxNumAckRanges = 32
|
||||
|
||||
// MinPacingDelay is the minimum duration that is used for packet pacing
|
||||
// If the packet packing frequency is higher, multiple packets might be sent at once.
|
||||
// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth.
|
||||
const MinPacingDelay = time.Millisecond
|
||||
|
||||
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
|
||||
// if no other value is configured.
|
||||
const DefaultConnectionIDLength = 4
|
||||
|
||||
// MaxActiveConnectionIDs is the number of connection IDs that we're storing.
|
||||
const MaxActiveConnectionIDs = 4
|
||||
|
||||
// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time.
|
||||
const MaxIssuedConnectionIDs = 6
|
||||
|
||||
// PacketsPerConnectionID is the number of packets we send using one connection ID.
|
||||
// If the peer provices us with enough new connection IDs, we switch to a new connection ID.
|
||||
const PacketsPerConnectionID = 10000
|
||||
|
||||
// AckDelayExponent is the ack delay exponent used when sending ACKs.
|
||||
const AckDelayExponent = 3
|
||||
|
||||
// Estimated timer granularity.
|
||||
// The loss detection timer will not be set to a value smaller than granularity.
|
||||
const TimerGranularity = time.Millisecond
|
||||
|
||||
// MaxAckDelay is the maximum time by which we delay sending ACKs.
|
||||
const MaxAckDelay = 25 * time.Millisecond
|
||||
|
||||
// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity.
|
||||
// This is the value that should be advertised to the peer.
|
||||
const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity
|
||||
|
||||
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
|
||||
const KeyUpdateInterval = 100 * 1000
|
||||
|
||||
// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received.
|
||||
const Max0RTTQueueingDuration = 100 * time.Millisecond
|
||||
|
||||
// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for.
|
||||
const Max0RTTQueues = 32
|
||||
|
||||
// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection.
|
||||
// When a new connection is created, all buffered packets are passed to the connection immediately.
|
||||
// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets.
|
||||
// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets.
|
||||
const Max0RTTQueueLen = 31
|
26
vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go
generated
vendored
Normal file
26
vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go
generated
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
package protocol
|
||||
|
||||
// Perspective determines if we're acting as a server or a client
|
||||
type Perspective int
|
||||
|
||||
// the perspectives
|
||||
const (
|
||||
PerspectiveServer Perspective = 1
|
||||
PerspectiveClient Perspective = 2
|
||||
)
|
||||
|
||||
// Opposite returns the perspective of the peer
|
||||
func (p Perspective) Opposite() Perspective {
|
||||
return 3 - p
|
||||
}
|
||||
|
||||
func (p Perspective) String() string {
|
||||
switch p {
|
||||
case PerspectiveServer:
|
||||
return "Server"
|
||||
case PerspectiveClient:
|
||||
return "Client"
|
||||
default:
|
||||
return "invalid perspective"
|
||||
}
|
||||
}
|
97
vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go
generated
vendored
Normal file
97
vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go
generated
vendored
Normal file
|
@ -0,0 +1,97 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// The PacketType is the Long Header Type
|
||||
type PacketType uint8
|
||||
|
||||
const (
|
||||
// PacketTypeInitial is the packet type of an Initial packet
|
||||
PacketTypeInitial PacketType = 1 + iota
|
||||
// PacketTypeRetry is the packet type of a Retry packet
|
||||
PacketTypeRetry
|
||||
// PacketTypeHandshake is the packet type of a Handshake packet
|
||||
PacketTypeHandshake
|
||||
// PacketType0RTT is the packet type of a 0-RTT packet
|
||||
PacketType0RTT
|
||||
)
|
||||
|
||||
func (t PacketType) String() string {
|
||||
switch t {
|
||||
case PacketTypeInitial:
|
||||
return "Initial"
|
||||
case PacketTypeRetry:
|
||||
return "Retry"
|
||||
case PacketTypeHandshake:
|
||||
return "Handshake"
|
||||
case PacketType0RTT:
|
||||
return "0-RTT Protected"
|
||||
default:
|
||||
return fmt.Sprintf("unknown packet type: %d", t)
|
||||
}
|
||||
}
|
||||
|
||||
type ECN uint8
|
||||
|
||||
const (
|
||||
ECNNon ECN = iota // 00
|
||||
ECT1 // 01
|
||||
ECT0 // 10
|
||||
ECNCE // 11
|
||||
)
|
||||
|
||||
// A ByteCount in QUIC
|
||||
type ByteCount int64
|
||||
|
||||
// MaxByteCount is the maximum value of a ByteCount
|
||||
const MaxByteCount = ByteCount(1<<62 - 1)
|
||||
|
||||
// InvalidByteCount is an invalid byte count
|
||||
const InvalidByteCount ByteCount = -1
|
||||
|
||||
// A StatelessResetToken is a stateless reset token.
|
||||
type StatelessResetToken [16]byte
|
||||
|
||||
// MaxPacketBufferSize maximum packet size of any QUIC packet, based on
|
||||
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
|
||||
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
|
||||
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
|
||||
const MaxPacketBufferSize ByteCount = 1452
|
||||
|
||||
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
||||
const MinInitialPacketSize = 1200
|
||||
|
||||
// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version
|
||||
// needs to have in order to trigger a Version Negotiation packet.
|
||||
const MinUnknownVersionPacketSize = MinInitialPacketSize
|
||||
|
||||
// MinStatelessResetSize is the minimum size of a stateless reset packet that we send
|
||||
const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */
|
||||
|
||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||
const MinConnectionIDLenInitial = 8
|
||||
|
||||
// DefaultAckDelayExponent is the default ack delay exponent
|
||||
const DefaultAckDelayExponent = 3
|
||||
|
||||
// MaxAckDelayExponent is the maximum ack delay exponent
|
||||
const MaxAckDelayExponent = 20
|
||||
|
||||
// DefaultMaxAckDelay is the default max_ack_delay
|
||||
const DefaultMaxAckDelay = 25 * time.Millisecond
|
||||
|
||||
// MaxMaxAckDelay is the maximum max_ack_delay
|
||||
const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond
|
||||
|
||||
// MaxConnIDLen is the maximum length of the connection ID
|
||||
const MaxConnIDLen = 20
|
||||
|
||||
// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using
|
||||
// AEAD_AES_128_GCM or AEAD_AES_265_GCM.
|
||||
const InvalidPacketLimitAES = 1 << 52
|
||||
|
||||
// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305.
|
||||
const InvalidPacketLimitChaCha = 1 << 36
|
76
vendor/github.com/quic-go/quic-go/internal/protocol/stream.go
generated
vendored
Normal file
76
vendor/github.com/quic-go/quic-go/internal/protocol/stream.go
generated
vendored
Normal file
|
@ -0,0 +1,76 @@
|
|||
package protocol
|
||||
|
||||
// StreamType encodes if this is a unidirectional or bidirectional stream
|
||||
type StreamType uint8
|
||||
|
||||
const (
|
||||
// StreamTypeUni is a unidirectional stream
|
||||
StreamTypeUni StreamType = iota
|
||||
// StreamTypeBidi is a bidirectional stream
|
||||
StreamTypeBidi
|
||||
)
|
||||
|
||||
// InvalidPacketNumber is a stream ID that is invalid.
|
||||
// The first valid stream ID in QUIC is 0.
|
||||
const InvalidStreamID StreamID = -1
|
||||
|
||||
// StreamNum is the stream number
|
||||
type StreamNum int64
|
||||
|
||||
const (
|
||||
// InvalidStreamNum is an invalid stream number.
|
||||
InvalidStreamNum = -1
|
||||
// MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames
|
||||
// and as the stream count in the transport parameters
|
||||
MaxStreamCount StreamNum = 1 << 60
|
||||
)
|
||||
|
||||
// StreamID calculates the stream ID.
|
||||
func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID {
|
||||
if s == 0 {
|
||||
return InvalidStreamID
|
||||
}
|
||||
var first StreamID
|
||||
switch stype {
|
||||
case StreamTypeBidi:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 0
|
||||
case PerspectiveServer:
|
||||
first = 1
|
||||
}
|
||||
case StreamTypeUni:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 2
|
||||
case PerspectiveServer:
|
||||
first = 3
|
||||
}
|
||||
}
|
||||
return first + 4*StreamID(s-1)
|
||||
}
|
||||
|
||||
// A StreamID in QUIC
|
||||
type StreamID int64
|
||||
|
||||
// InitiatedBy says if the stream was initiated by the client or by the server
|
||||
func (s StreamID) InitiatedBy() Perspective {
|
||||
if s%2 == 0 {
|
||||
return PerspectiveClient
|
||||
}
|
||||
return PerspectiveServer
|
||||
}
|
||||
|
||||
// Type says if this is a unidirectional or bidirectional stream
|
||||
func (s StreamID) Type() StreamType {
|
||||
if s%4 >= 2 {
|
||||
return StreamTypeUni
|
||||
}
|
||||
return StreamTypeBidi
|
||||
}
|
||||
|
||||
// StreamNum returns how many streams in total are below this
|
||||
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
|
||||
func (s StreamID) StreamNum() StreamNum {
|
||||
return StreamNum(s/4) + 1
|
||||
}
|
114
vendor/github.com/quic-go/quic-go/internal/protocol/version.go
generated
vendored
Normal file
114
vendor/github.com/quic-go/quic-go/internal/protocol/version.go
generated
vendored
Normal file
|
@ -0,0 +1,114 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
// VersionNumber is a version number as int
|
||||
type VersionNumber uint32
|
||||
|
||||
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
|
||||
const (
|
||||
gquicVersion0 = 0x51303030
|
||||
maxGquicVersion = 0x51303439
|
||||
)
|
||||
|
||||
// The version numbers, making grepping easier
|
||||
const (
|
||||
VersionTLS VersionNumber = 0x1
|
||||
VersionWhatever VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter
|
||||
VersionUnknown VersionNumber = math.MaxUint32
|
||||
VersionDraft29 VersionNumber = 0xff00001d
|
||||
Version1 VersionNumber = 0x1
|
||||
Version2 VersionNumber = 0x6b3343cf
|
||||
)
|
||||
|
||||
// SupportedVersions lists the versions that the server supports
|
||||
// must be in sorted descending order
|
||||
var SupportedVersions = []VersionNumber{Version1, Version2, VersionDraft29}
|
||||
|
||||
// IsValidVersion says if the version is known to quic-go
|
||||
func IsValidVersion(v VersionNumber) bool {
|
||||
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
|
||||
}
|
||||
|
||||
func (vn VersionNumber) String() string {
|
||||
// For releases, VersionTLS will be set to a draft version.
|
||||
// A switch statement can't contain duplicate cases.
|
||||
if vn == VersionTLS && VersionTLS != VersionDraft29 && VersionTLS != Version1 {
|
||||
return "TLS dev version (WIP)"
|
||||
}
|
||||
//nolint:exhaustive
|
||||
switch vn {
|
||||
case VersionWhatever:
|
||||
return "whatever"
|
||||
case VersionUnknown:
|
||||
return "unknown"
|
||||
case VersionDraft29:
|
||||
return "draft-29"
|
||||
case Version1:
|
||||
return "v1"
|
||||
case Version2:
|
||||
return "v2"
|
||||
default:
|
||||
if vn.isGQUIC() {
|
||||
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
|
||||
}
|
||||
return fmt.Sprintf("%#x", uint32(vn))
|
||||
}
|
||||
}
|
||||
|
||||
func (vn VersionNumber) isGQUIC() bool {
|
||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||
}
|
||||
|
||||
func (vn VersionNumber) toGQUICVersion() int {
|
||||
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
|
||||
}
|
||||
|
||||
// IsSupportedVersion returns true if the server supports this version
|
||||
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
|
||||
for _, t := range supported {
|
||||
if t == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ChooseSupportedVersion finds the best version in the overlap of ours and theirs
|
||||
// ours is a slice of versions that we support, sorted by our preference (descending)
|
||||
// theirs is a slice of versions offered by the peer. The order does not matter.
|
||||
// The bool returned indicates if a matching version was found.
|
||||
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
|
||||
for _, ourVer := range ours {
|
||||
for _, theirVer := range theirs {
|
||||
if ourVer == theirVer {
|
||||
return ourVer, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
|
||||
func generateReservedVersion() VersionNumber {
|
||||
b := make([]byte, 4)
|
||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
||||
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
|
||||
}
|
||||
|
||||
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
|
||||
func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
|
||||
b := make([]byte, 1)
|
||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
||||
randPos := int(b[0]) % (len(supported) + 1)
|
||||
greased := make([]VersionNumber, len(supported)+1)
|
||||
copy(greased, supported[:randPos])
|
||||
greased[randPos] = generateReservedVersion()
|
||||
copy(greased[randPos+1:], supported[randPos:])
|
||||
return greased
|
||||
}
|
88
vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go
generated
vendored
Normal file
88
vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go
generated
vendored
Normal file
|
@ -0,0 +1,88 @@
|
|||
package qerr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
// TransportErrorCode is a QUIC transport error.
|
||||
type TransportErrorCode uint64
|
||||
|
||||
// The error codes defined by QUIC
|
||||
const (
|
||||
NoError TransportErrorCode = 0x0
|
||||
InternalError TransportErrorCode = 0x1
|
||||
ConnectionRefused TransportErrorCode = 0x2
|
||||
FlowControlError TransportErrorCode = 0x3
|
||||
StreamLimitError TransportErrorCode = 0x4
|
||||
StreamStateError TransportErrorCode = 0x5
|
||||
FinalSizeError TransportErrorCode = 0x6
|
||||
FrameEncodingError TransportErrorCode = 0x7
|
||||
TransportParameterError TransportErrorCode = 0x8
|
||||
ConnectionIDLimitError TransportErrorCode = 0x9
|
||||
ProtocolViolation TransportErrorCode = 0xa
|
||||
InvalidToken TransportErrorCode = 0xb
|
||||
ApplicationErrorErrorCode TransportErrorCode = 0xc
|
||||
CryptoBufferExceeded TransportErrorCode = 0xd
|
||||
KeyUpdateError TransportErrorCode = 0xe
|
||||
AEADLimitReached TransportErrorCode = 0xf
|
||||
NoViablePathError TransportErrorCode = 0x10
|
||||
)
|
||||
|
||||
func (e TransportErrorCode) IsCryptoError() bool {
|
||||
return e >= 0x100 && e < 0x200
|
||||
}
|
||||
|
||||
// Message is a description of the error.
|
||||
// It only returns a non-empty string for crypto errors.
|
||||
func (e TransportErrorCode) Message() string {
|
||||
if !e.IsCryptoError() {
|
||||
return ""
|
||||
}
|
||||
return qtls.Alert(e - 0x100).Error()
|
||||
}
|
||||
|
||||
func (e TransportErrorCode) String() string {
|
||||
switch e {
|
||||
case NoError:
|
||||
return "NO_ERROR"
|
||||
case InternalError:
|
||||
return "INTERNAL_ERROR"
|
||||
case ConnectionRefused:
|
||||
return "CONNECTION_REFUSED"
|
||||
case FlowControlError:
|
||||
return "FLOW_CONTROL_ERROR"
|
||||
case StreamLimitError:
|
||||
return "STREAM_LIMIT_ERROR"
|
||||
case StreamStateError:
|
||||
return "STREAM_STATE_ERROR"
|
||||
case FinalSizeError:
|
||||
return "FINAL_SIZE_ERROR"
|
||||
case FrameEncodingError:
|
||||
return "FRAME_ENCODING_ERROR"
|
||||
case TransportParameterError:
|
||||
return "TRANSPORT_PARAMETER_ERROR"
|
||||
case ConnectionIDLimitError:
|
||||
return "CONNECTION_ID_LIMIT_ERROR"
|
||||
case ProtocolViolation:
|
||||
return "PROTOCOL_VIOLATION"
|
||||
case InvalidToken:
|
||||
return "INVALID_TOKEN"
|
||||
case ApplicationErrorErrorCode:
|
||||
return "APPLICATION_ERROR"
|
||||
case CryptoBufferExceeded:
|
||||
return "CRYPTO_BUFFER_EXCEEDED"
|
||||
case KeyUpdateError:
|
||||
return "KEY_UPDATE_ERROR"
|
||||
case AEADLimitReached:
|
||||
return "AEAD_LIMIT_REACHED"
|
||||
case NoViablePathError:
|
||||
return "NO_VIABLE_PATH"
|
||||
default:
|
||||
if e.IsCryptoError() {
|
||||
return fmt.Sprintf("CRYPTO_ERROR %#x", uint16(e))
|
||||
}
|
||||
return fmt.Sprintf("unknown error code: %#x", uint16(e))
|
||||
}
|
||||
}
|
131
vendor/github.com/quic-go/quic-go/internal/qerr/errors.go
generated
vendored
Normal file
131
vendor/github.com/quic-go/quic-go/internal/qerr/errors.go
generated
vendored
Normal file
|
@ -0,0 +1,131 @@
|
|||
package qerr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrHandshakeTimeout = &HandshakeTimeoutError{}
|
||||
ErrIdleTimeout = &IdleTimeoutError{}
|
||||
)
|
||||
|
||||
type TransportError struct {
|
||||
Remote bool
|
||||
FrameType uint64
|
||||
ErrorCode TransportErrorCode
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
var _ error = &TransportError{}
|
||||
|
||||
// NewLocalCryptoError create a new TransportError instance for a crypto error
|
||||
func NewLocalCryptoError(tlsAlert uint8, errorMessage string) *TransportError {
|
||||
return &TransportError{
|
||||
ErrorCode: 0x100 + TransportErrorCode(tlsAlert),
|
||||
ErrorMessage: errorMessage,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *TransportError) Error() string {
|
||||
str := fmt.Sprintf("%s (%s)", e.ErrorCode.String(), getRole(e.Remote))
|
||||
if e.FrameType != 0 {
|
||||
str += fmt.Sprintf(" (frame type: %#x)", e.FrameType)
|
||||
}
|
||||
msg := e.ErrorMessage
|
||||
if len(msg) == 0 {
|
||||
msg = e.ErrorCode.Message()
|
||||
}
|
||||
if len(msg) == 0 {
|
||||
return str
|
||||
}
|
||||
return str + ": " + msg
|
||||
}
|
||||
|
||||
func (e *TransportError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
// An ApplicationErrorCode is an application-defined error code.
|
||||
type ApplicationErrorCode uint64
|
||||
|
||||
func (e *ApplicationError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
// A StreamErrorCode is an error code used to cancel streams.
|
||||
type StreamErrorCode uint64
|
||||
|
||||
type ApplicationError struct {
|
||||
Remote bool
|
||||
ErrorCode ApplicationErrorCode
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
var _ error = &ApplicationError{}
|
||||
|
||||
func (e *ApplicationError) Error() string {
|
||||
if len(e.ErrorMessage) == 0 {
|
||||
return fmt.Sprintf("Application error %#x (%s)", e.ErrorCode, getRole(e.Remote))
|
||||
}
|
||||
return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage)
|
||||
}
|
||||
|
||||
type IdleTimeoutError struct{}
|
||||
|
||||
var _ error = &IdleTimeoutError{}
|
||||
|
||||
func (e *IdleTimeoutError) Timeout() bool { return true }
|
||||
func (e *IdleTimeoutError) Temporary() bool { return false }
|
||||
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
|
||||
func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed }
|
||||
|
||||
type HandshakeTimeoutError struct{}
|
||||
|
||||
var _ error = &HandshakeTimeoutError{}
|
||||
|
||||
func (e *HandshakeTimeoutError) Timeout() bool { return true }
|
||||
func (e *HandshakeTimeoutError) Temporary() bool { return false }
|
||||
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
|
||||
func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed }
|
||||
|
||||
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
|
||||
type VersionNegotiationError struct {
|
||||
Ours []protocol.VersionNumber
|
||||
Theirs []protocol.VersionNumber
|
||||
}
|
||||
|
||||
func (e *VersionNegotiationError) Error() string {
|
||||
return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs)
|
||||
}
|
||||
|
||||
func (e *VersionNegotiationError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
// A StatelessResetError occurs when we receive a stateless reset.
|
||||
type StatelessResetError struct {
|
||||
Token protocol.StatelessResetToken
|
||||
}
|
||||
|
||||
var _ net.Error = &StatelessResetError{}
|
||||
|
||||
func (e *StatelessResetError) Error() string {
|
||||
return fmt.Sprintf("received a stateless reset with token %x", e.Token)
|
||||
}
|
||||
|
||||
func (e *StatelessResetError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
func (e *StatelessResetError) Timeout() bool { return false }
|
||||
func (e *StatelessResetError) Temporary() bool { return true }
|
||||
|
||||
func getRole(remote bool) string {
|
||||
if remote {
|
||||
return "remote"
|
||||
}
|
||||
return "local"
|
||||
}
|
99
vendor/github.com/quic-go/quic-go/internal/qtls/go118.go
generated
vendored
Normal file
99
vendor/github.com/quic-go/quic-go/internal/qtls/go118.go
generated
vendored
Normal file
|
@ -0,0 +1,99 @@
|
|||
//go:build go1.18 && !go1.19
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/quic-go/qtls-go1-18"
|
||||
)
|
||||
|
||||
type (
|
||||
// Alert is a TLS alert
|
||||
Alert = qtls.Alert
|
||||
// A Certificate is qtls.Certificate.
|
||||
Certificate = qtls.Certificate
|
||||
// CertificateRequestInfo contains inforamtion about a certificate request.
|
||||
CertificateRequestInfo = qtls.CertificateRequestInfo
|
||||
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
|
||||
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
|
||||
// ClientHelloInfo contains information about a ClientHello.
|
||||
ClientHelloInfo = qtls.ClientHelloInfo
|
||||
// ClientSessionCache is a cache used for session resumption.
|
||||
ClientSessionCache = qtls.ClientSessionCache
|
||||
// ClientSessionState is a state needed for session resumption.
|
||||
ClientSessionState = qtls.ClientSessionState
|
||||
// A Config is a qtls.Config.
|
||||
Config = qtls.Config
|
||||
// A Conn is a qtls.Conn.
|
||||
Conn = qtls.Conn
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
ConnectionState = qtls.ConnectionStateWith0RTT
|
||||
// EncryptionLevel is the encryption level of a message.
|
||||
EncryptionLevel = qtls.EncryptionLevel
|
||||
// Extension is a TLS extension
|
||||
Extension = qtls.Extension
|
||||
// ExtraConfig is the qtls.ExtraConfig
|
||||
ExtraConfig = qtls.ExtraConfig
|
||||
// RecordLayer is a qtls RecordLayer.
|
||||
RecordLayer = qtls.RecordLayer
|
||||
)
|
||||
|
||||
const (
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake = qtls.EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT = qtls.Encryption0RTT
|
||||
// EncryptionApplication is the application data encryption level
|
||||
EncryptionApplication = qtls.EncryptionApplication
|
||||
)
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return qtls.AEADAESGCMTLS13(key, fixedNonce)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection.
|
||||
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Client(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
// Server returns a new TLS server side connection.
|
||||
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Server(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
func GetConnectionState(conn *Conn) ConnectionState {
|
||||
return conn.ConnectionStateWith0RTT()
|
||||
}
|
||||
|
||||
// ToTLSConnectionState extracts the tls.ConnectionState
|
||||
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
|
||||
return cs.ConnectionState
|
||||
}
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-18.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
|
||||
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
|
||||
val := cipherSuiteTLS13ByID(id)
|
||||
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
|
||||
return &qtls.CipherSuiteTLS13{
|
||||
ID: cs.ID,
|
||||
KeyLen: cs.KeyLen,
|
||||
AEAD: cs.AEAD,
|
||||
Hash: cs.Hash,
|
||||
}
|
||||
}
|
99
vendor/github.com/quic-go/quic-go/internal/qtls/go119.go
generated
vendored
Normal file
99
vendor/github.com/quic-go/quic-go/internal/qtls/go119.go
generated
vendored
Normal file
|
@ -0,0 +1,99 @@
|
|||
//go:build go1.19 && !go1.20
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/quic-go/qtls-go1-19"
|
||||
)
|
||||
|
||||
type (
|
||||
// Alert is a TLS alert
|
||||
Alert = qtls.Alert
|
||||
// A Certificate is qtls.Certificate.
|
||||
Certificate = qtls.Certificate
|
||||
// CertificateRequestInfo contains information about a certificate request.
|
||||
CertificateRequestInfo = qtls.CertificateRequestInfo
|
||||
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
|
||||
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
|
||||
// ClientHelloInfo contains information about a ClientHello.
|
||||
ClientHelloInfo = qtls.ClientHelloInfo
|
||||
// ClientSessionCache is a cache used for session resumption.
|
||||
ClientSessionCache = qtls.ClientSessionCache
|
||||
// ClientSessionState is a state needed for session resumption.
|
||||
ClientSessionState = qtls.ClientSessionState
|
||||
// A Config is a qtls.Config.
|
||||
Config = qtls.Config
|
||||
// A Conn is a qtls.Conn.
|
||||
Conn = qtls.Conn
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
ConnectionState = qtls.ConnectionStateWith0RTT
|
||||
// EncryptionLevel is the encryption level of a message.
|
||||
EncryptionLevel = qtls.EncryptionLevel
|
||||
// Extension is a TLS extension
|
||||
Extension = qtls.Extension
|
||||
// ExtraConfig is the qtls.ExtraConfig
|
||||
ExtraConfig = qtls.ExtraConfig
|
||||
// RecordLayer is a qtls RecordLayer.
|
||||
RecordLayer = qtls.RecordLayer
|
||||
)
|
||||
|
||||
const (
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake = qtls.EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT = qtls.Encryption0RTT
|
||||
// EncryptionApplication is the application data encryption level
|
||||
EncryptionApplication = qtls.EncryptionApplication
|
||||
)
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return qtls.AEADAESGCMTLS13(key, fixedNonce)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection.
|
||||
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Client(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
// Server returns a new TLS server side connection.
|
||||
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Server(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
func GetConnectionState(conn *Conn) ConnectionState {
|
||||
return conn.ConnectionStateWith0RTT()
|
||||
}
|
||||
|
||||
// ToTLSConnectionState extracts the tls.ConnectionState
|
||||
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
|
||||
return cs.ConnectionState
|
||||
}
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-19.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
|
||||
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
|
||||
val := cipherSuiteTLS13ByID(id)
|
||||
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
|
||||
return &qtls.CipherSuiteTLS13{
|
||||
ID: cs.ID,
|
||||
KeyLen: cs.KeyLen,
|
||||
AEAD: cs.AEAD,
|
||||
Hash: cs.Hash,
|
||||
}
|
||||
}
|
99
vendor/github.com/quic-go/quic-go/internal/qtls/go120.go
generated
vendored
Normal file
99
vendor/github.com/quic-go/quic-go/internal/qtls/go120.go
generated
vendored
Normal file
|
@ -0,0 +1,99 @@
|
|||
//go:build go1.20
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/quic-go/qtls-go1-20"
|
||||
)
|
||||
|
||||
type (
|
||||
// Alert is a TLS alert
|
||||
Alert = qtls.Alert
|
||||
// A Certificate is qtls.Certificate.
|
||||
Certificate = qtls.Certificate
|
||||
// CertificateRequestInfo contains information about a certificate request.
|
||||
CertificateRequestInfo = qtls.CertificateRequestInfo
|
||||
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
|
||||
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
|
||||
// ClientHelloInfo contains information about a ClientHello.
|
||||
ClientHelloInfo = qtls.ClientHelloInfo
|
||||
// ClientSessionCache is a cache used for session resumption.
|
||||
ClientSessionCache = qtls.ClientSessionCache
|
||||
// ClientSessionState is a state needed for session resumption.
|
||||
ClientSessionState = qtls.ClientSessionState
|
||||
// A Config is a qtls.Config.
|
||||
Config = qtls.Config
|
||||
// A Conn is a qtls.Conn.
|
||||
Conn = qtls.Conn
|
||||
// ConnectionState contains information about the state of the connection.
|
||||
ConnectionState = qtls.ConnectionStateWith0RTT
|
||||
// EncryptionLevel is the encryption level of a message.
|
||||
EncryptionLevel = qtls.EncryptionLevel
|
||||
// Extension is a TLS extension
|
||||
Extension = qtls.Extension
|
||||
// ExtraConfig is the qtls.ExtraConfig
|
||||
ExtraConfig = qtls.ExtraConfig
|
||||
// RecordLayer is a qtls RecordLayer.
|
||||
RecordLayer = qtls.RecordLayer
|
||||
)
|
||||
|
||||
const (
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake = qtls.EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT = qtls.Encryption0RTT
|
||||
// EncryptionApplication is the application data encryption level
|
||||
EncryptionApplication = qtls.EncryptionApplication
|
||||
)
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return qtls.AEADAESGCMTLS13(key, fixedNonce)
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection.
|
||||
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Client(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
// Server returns a new TLS server side connection.
|
||||
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
return qtls.Server(conn, config, extraConfig)
|
||||
}
|
||||
|
||||
func GetConnectionState(conn *Conn) ConnectionState {
|
||||
return conn.ConnectionStateWith0RTT()
|
||||
}
|
||||
|
||||
// ToTLSConnectionState extracts the tls.ConnectionState
|
||||
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
|
||||
return cs.ConnectionState
|
||||
}
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-20.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
|
||||
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
|
||||
val := cipherSuiteTLS13ByID(id)
|
||||
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
|
||||
return &qtls.CipherSuiteTLS13{
|
||||
ID: cs.ID,
|
||||
KeyLen: cs.KeyLen,
|
||||
AEAD: cs.AEAD,
|
||||
Hash: cs.Hash,
|
||||
}
|
||||
}
|
5
vendor/github.com/quic-go/quic-go/internal/qtls/go121.go
generated
vendored
Normal file
5
vendor/github.com/quic-go/quic-go/internal/qtls/go121.go
generated
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
var _ int = "The version of quic-go you're using can't be built on Go 1.21 yet. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions."
|
5
vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go
generated
vendored
Normal file
5
vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go
generated
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
//go:build !go1.18
|
||||
|
||||
package qtls
|
||||
|
||||
var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions."
|
22
vendor/github.com/quic-go/quic-go/internal/utils/atomic_bool.go
generated
vendored
Normal file
22
vendor/github.com/quic-go/quic-go/internal/utils/atomic_bool.go
generated
vendored
Normal file
|
@ -0,0 +1,22 @@
|
|||
package utils
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// An AtomicBool is an atomic bool
|
||||
type AtomicBool struct {
|
||||
v int32
|
||||
}
|
||||
|
||||
// Set sets the value
|
||||
func (a *AtomicBool) Set(value bool) {
|
||||
var n int32
|
||||
if value {
|
||||
n = 1
|
||||
}
|
||||
atomic.StoreInt32(&a.v, n)
|
||||
}
|
||||
|
||||
// Get gets the value
|
||||
func (a *AtomicBool) Get() bool {
|
||||
return atomic.LoadInt32(&a.v) != 0
|
||||
}
|
26
vendor/github.com/quic-go/quic-go/internal/utils/buffered_write_closer.go
generated
vendored
Normal file
26
vendor/github.com/quic-go/quic-go/internal/utils/buffered_write_closer.go
generated
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type bufferedWriteCloser struct {
|
||||
*bufio.Writer
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer
|
||||
func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser {
|
||||
return &bufferedWriteCloser{
|
||||
Writer: writer,
|
||||
Closer: closer,
|
||||
}
|
||||
}
|
||||
|
||||
func (h bufferedWriteCloser) Close() error {
|
||||
if err := h.Writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return h.Closer.Close()
|
||||
}
|
21
vendor/github.com/quic-go/quic-go/internal/utils/byteorder.go
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/internal/utils/byteorder.go
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
|
||||
type ByteOrder interface {
|
||||
Uint32([]byte) uint32
|
||||
Uint24([]byte) uint32
|
||||
Uint16([]byte) uint16
|
||||
|
||||
ReadUint32(io.ByteReader) (uint32, error)
|
||||
ReadUint24(io.ByteReader) (uint32, error)
|
||||
ReadUint16(io.ByteReader) (uint16, error)
|
||||
|
||||
WriteUint32(*bytes.Buffer, uint32)
|
||||
WriteUint24(*bytes.Buffer, uint32)
|
||||
WriteUint16(*bytes.Buffer, uint16)
|
||||
}
|
103
vendor/github.com/quic-go/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
Normal file
103
vendor/github.com/quic-go/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
Normal file
|
@ -0,0 +1,103 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
// BigEndian is the big-endian implementation of ByteOrder.
|
||||
var BigEndian ByteOrder = bigEndian{}
|
||||
|
||||
type bigEndian struct{}
|
||||
|
||||
var _ ByteOrder = &bigEndian{}
|
||||
|
||||
// ReadUintN reads N bytes
|
||||
func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
|
||||
var res uint64
|
||||
for i := uint8(0); i < length; i++ {
|
||||
bt, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res ^= uint64(bt) << ((length - 1 - i) * 8)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ReadUint32 reads a uint32
|
||||
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
||||
var b1, b2, b3, b4 uint8
|
||||
var err error
|
||||
if b4, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
|
||||
}
|
||||
|
||||
// ReadUint24 reads a uint24
|
||||
func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) {
|
||||
var b1, b2, b3 uint8
|
||||
var err error
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil
|
||||
}
|
||||
|
||||
// ReadUint16 reads a uint16
|
||||
func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
|
||||
var b1, b2 uint8
|
||||
var err error
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(b1) + uint16(b2)<<8, nil
|
||||
}
|
||||
|
||||
func (bigEndian) Uint32(b []byte) uint32 {
|
||||
return binary.BigEndian.Uint32(b)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint24(b []byte) uint32 {
|
||||
_ = b[2] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16
|
||||
}
|
||||
|
||||
func (bigEndian) Uint16(b []byte) uint16 {
|
||||
return binary.BigEndian.Uint16(b)
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32
|
||||
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint24 writes a uint24
|
||||
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint16 writes a uint16
|
||||
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
||||
b.Write([]byte{uint8(i >> 8), uint8(i)})
|
||||
}
|
10
vendor/github.com/quic-go/quic-go/internal/utils/ip.go
generated
vendored
Normal file
10
vendor/github.com/quic-go/quic-go/internal/utils/ip.go
generated
vendored
Normal file
|
@ -0,0 +1,10 @@
|
|||
package utils
|
||||
|
||||
import "net"
|
||||
|
||||
func IsIPv4(ip net.IP) bool {
|
||||
// If ip is not an IPv4 address, To4 returns nil.
|
||||
// Note that there might be some corner cases, where this is not correct.
|
||||
// See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
|
||||
return ip.To4() != nil
|
||||
}
|
6
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md
generated
vendored
Normal file
6
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md
generated
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
# Usage
|
||||
|
||||
This is the Go standard library implementation of a linked list
|
||||
(https://golang.org/src/container/list/list.go), with the following modifications:
|
||||
* it uses Go generics
|
||||
* it allows passing in a `sync.Pool` (via the `NewWithPool` constructor) to reduce allocations of `Element` structs
|
264
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/linkedlist.go
generated
vendored
Normal file
264
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/linkedlist.go
generated
vendored
Normal file
|
@ -0,0 +1,264 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package list implements a doubly linked list.
|
||||
//
|
||||
// To iterate over a list (where l is a *List[T]):
|
||||
//
|
||||
// for e := l.Front(); e != nil; e = e.Next() {
|
||||
// // do something with e.Value
|
||||
// }
|
||||
package list
|
||||
|
||||
import "sync"
|
||||
|
||||
func NewPool[T any]() *sync.Pool {
|
||||
return &sync.Pool{New: func() any { return &Element[T]{} }}
|
||||
}
|
||||
|
||||
// Element is an element of a linked list.
|
||||
type Element[T any] struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *Element[T]
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *List[T]
|
||||
|
||||
// The value stored with this element.
|
||||
Value T
|
||||
}
|
||||
|
||||
// Next returns the next list element or nil.
|
||||
func (e *Element[T]) Next() *Element[T] {
|
||||
if p := e.next; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prev returns the previous list element or nil.
|
||||
func (e *Element[T]) Prev() *Element[T] {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Element[T]) List() *List[T] {
|
||||
return e.list
|
||||
}
|
||||
|
||||
// List represents a doubly linked list.
|
||||
// The zero value for List is an empty list ready to use.
|
||||
type List[T any] struct {
|
||||
root Element[T] // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list length excluding (this) sentinel element
|
||||
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *List[T]) Init() *List[T] {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// New returns an initialized list.
|
||||
func New[T any]() *List[T] { return new(List[T]).Init() }
|
||||
|
||||
// NewWithPool returns an initialized list, using a sync.Pool for list elements.
|
||||
func NewWithPool[T any](pool *sync.Pool) *List[T] {
|
||||
l := &List[T]{pool: pool}
|
||||
return l.Init()
|
||||
}
|
||||
|
||||
// Len returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *List[T]) Len() int { return l.len }
|
||||
|
||||
// Front returns the first element of list l or nil if the list is empty.
|
||||
func (l *List[T]) Front() *Element[T] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.next
|
||||
}
|
||||
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *List[T]) Back() *Element[T] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero List value.
|
||||
func (l *List[T]) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
}
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *List[T]) insert(e, at *Element[T]) *Element[T] {
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||
func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] {
|
||||
var e *Element[T]
|
||||
if l.pool != nil {
|
||||
e = l.pool.Get().(*Element[T])
|
||||
} else {
|
||||
e = &Element[T]{}
|
||||
}
|
||||
e.Value = v
|
||||
return l.insert(e, at)
|
||||
}
|
||||
|
||||
// remove removes e from its list, decrements l.len
|
||||
func (l *List[T]) remove(e *Element[T]) {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
if l.pool != nil {
|
||||
l.pool.Put(e)
|
||||
}
|
||||
l.len--
|
||||
}
|
||||
|
||||
// move moves e to next to at.
|
||||
func (l *List[T]) move(e, at *Element[T]) {
|
||||
if e == at {
|
||||
return
|
||||
}
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
}
|
||||
|
||||
// Remove removes e from l if e is an element of list l.
|
||||
// It returns the element value e.Value.
|
||||
// The element must not be nil.
|
||||
func (l *List[T]) Remove(e *Element[T]) T {
|
||||
v := e.Value
|
||||
if e.list == l {
|
||||
// if e.list == l, l must have been initialized when e was inserted
|
||||
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||
l.remove(e)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *List[T]) PushFront(v T) *Element[T] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, &l.root)
|
||||
}
|
||||
|
||||
// PushBack inserts a new element e with value v at the back of list l and returns e.
|
||||
func (l *List[T]) PushBack(v T) *Element[T] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, l.root.prev)
|
||||
}
|
||||
|
||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *List[T]) MoveToFront(e *Element[T]) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.move(e, &l.root)
|
||||
}
|
||||
|
||||
// MoveToBack moves element e to the back of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *List[T]) MoveToBack(e *Element[T]) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.move(e, l.root.prev)
|
||||
}
|
||||
|
||||
// MoveBefore moves element e to its new position before mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *List[T]) MoveBefore(e, mark *Element[T]) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.move(e, mark.prev)
|
||||
}
|
||||
|
||||
// MoveAfter moves element e to its new position after mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *List[T]) MoveAfter(e, mark *Element[T]) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.move(e, mark)
|
||||
}
|
||||
|
||||
// PushBackList inserts a copy of another list at the back of list l.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *List[T]) PushBackList(other *List[T]) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||
l.insertValue(e.Value, l.root.prev)
|
||||
}
|
||||
}
|
||||
|
||||
// PushFrontList inserts a copy of another list at the front of list l.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *List[T]) PushFrontList(other *List[T]) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
||||
l.insertValue(e.Value, &l.root)
|
||||
}
|
||||
}
|
131
vendor/github.com/quic-go/quic-go/internal/utils/log.go
generated
vendored
Normal file
131
vendor/github.com/quic-go/quic-go/internal/utils/log.go
generated
vendored
Normal file
|
@ -0,0 +1,131 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogLevel of quic-go
|
||||
type LogLevel uint8
|
||||
|
||||
const (
|
||||
// LogLevelNothing disables
|
||||
LogLevelNothing LogLevel = iota
|
||||
// LogLevelError enables err logs
|
||||
LogLevelError
|
||||
// LogLevelInfo enables info logs (e.g. packets)
|
||||
LogLevelInfo
|
||||
// LogLevelDebug enables debug logs (e.g. packet contents)
|
||||
LogLevelDebug
|
||||
)
|
||||
|
||||
const logEnv = "QUIC_GO_LOG_LEVEL"
|
||||
|
||||
// A Logger logs.
|
||||
type Logger interface {
|
||||
SetLogLevel(LogLevel)
|
||||
SetLogTimeFormat(format string)
|
||||
WithPrefix(prefix string) Logger
|
||||
Debug() bool
|
||||
|
||||
Errorf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Debugf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// DefaultLogger is used by quic-go for logging.
|
||||
var DefaultLogger Logger
|
||||
|
||||
type defaultLogger struct {
|
||||
prefix string
|
||||
|
||||
logLevel LogLevel
|
||||
timeFormat string
|
||||
}
|
||||
|
||||
var _ Logger = &defaultLogger{}
|
||||
|
||||
// SetLogLevel sets the log level
|
||||
func (l *defaultLogger) SetLogLevel(level LogLevel) {
|
||||
l.logLevel = level
|
||||
}
|
||||
|
||||
// SetLogTimeFormat sets the format of the timestamp
|
||||
// an empty string disables the logging of timestamps
|
||||
func (l *defaultLogger) SetLogTimeFormat(format string) {
|
||||
log.SetFlags(0) // disable timestamp logging done by the log package
|
||||
l.timeFormat = format
|
||||
}
|
||||
|
||||
// Debugf logs something
|
||||
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
|
||||
if l.logLevel == LogLevelDebug {
|
||||
l.logMessage(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Infof logs something
|
||||
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||
if l.logLevel >= LogLevelInfo {
|
||||
l.logMessage(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Errorf logs something
|
||||
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
if l.logLevel >= LogLevelError {
|
||||
l.logMessage(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *defaultLogger) logMessage(format string, args ...interface{}) {
|
||||
var pre string
|
||||
|
||||
if len(l.timeFormat) > 0 {
|
||||
pre = time.Now().Format(l.timeFormat) + " "
|
||||
}
|
||||
if len(l.prefix) > 0 {
|
||||
pre += l.prefix + " "
|
||||
}
|
||||
log.Printf(pre+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) WithPrefix(prefix string) Logger {
|
||||
if len(l.prefix) > 0 {
|
||||
prefix = l.prefix + " " + prefix
|
||||
}
|
||||
return &defaultLogger{
|
||||
logLevel: l.logLevel,
|
||||
timeFormat: l.timeFormat,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Debug returns true if the log level is LogLevelDebug
|
||||
func (l *defaultLogger) Debug() bool {
|
||||
return l.logLevel == LogLevelDebug
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultLogger = &defaultLogger{}
|
||||
DefaultLogger.SetLogLevel(readLoggingEnv())
|
||||
}
|
||||
|
||||
func readLoggingEnv() LogLevel {
|
||||
switch strings.ToLower(os.Getenv(logEnv)) {
|
||||
case "":
|
||||
return LogLevelNothing
|
||||
case "debug":
|
||||
return LogLevelDebug
|
||||
case "info":
|
||||
return LogLevelInfo
|
||||
case "error":
|
||||
return LogLevelError
|
||||
default:
|
||||
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/quic-go/quic-go/wiki/Logging")
|
||||
return LogLevelNothing
|
||||
}
|
||||
}
|
72
vendor/github.com/quic-go/quic-go/internal/utils/minmax.go
generated
vendored
Normal file
72
vendor/github.com/quic-go/quic-go/internal/utils/minmax.go
generated
vendored
Normal file
|
@ -0,0 +1,72 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// InfDuration is a duration of infinite length
|
||||
const InfDuration = time.Duration(math.MaxInt64)
|
||||
|
||||
func Max[T constraints.Ordered](a, b T) T {
|
||||
if a < b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func Min[T constraints.Ordered](a, b T) T {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// MinNonZeroDuration return the minimum duration that's not zero.
|
||||
func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
||||
if a == 0 {
|
||||
return b
|
||||
}
|
||||
if b == 0 {
|
||||
return a
|
||||
}
|
||||
return Min(a, b)
|
||||
}
|
||||
|
||||
// AbsDuration returns the absolute value of a time duration
|
||||
func AbsDuration(d time.Duration) time.Duration {
|
||||
if d >= 0 {
|
||||
return d
|
||||
}
|
||||
return -d
|
||||
}
|
||||
|
||||
// MinTime returns the earlier time
|
||||
func MinTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// MinNonZeroTime returns the earlist time that is not time.Time{}
|
||||
// If both a and b are time.Time{}, it returns time.Time{}
|
||||
func MinNonZeroTime(a, b time.Time) time.Time {
|
||||
if a.IsZero() {
|
||||
return b
|
||||
}
|
||||
if b.IsZero() {
|
||||
return a
|
||||
}
|
||||
return MinTime(a, b)
|
||||
}
|
||||
|
||||
// MaxTime returns the later time
|
||||
func MaxTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
29
vendor/github.com/quic-go/quic-go/internal/utils/rand.go
generated
vendored
Normal file
29
vendor/github.com/quic-go/quic-go/internal/utils/rand.go
generated
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand.
|
||||
type Rand struct {
|
||||
buf [4]byte
|
||||
}
|
||||
|
||||
func (r *Rand) Int31() int32 {
|
||||
rand.Read(r.buf[:])
|
||||
return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31))
|
||||
}
|
||||
|
||||
// copied from the standard library math/rand implementation of Int63n
|
||||
func (r *Rand) Int31n(n int32) int32 {
|
||||
if n&(n-1) == 0 { // n is power of two, can mask
|
||||
return r.Int31() & (n - 1)
|
||||
}
|
||||
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
|
||||
v := r.Int31()
|
||||
for v > max {
|
||||
v = r.Int31()
|
||||
}
|
||||
return v % n
|
||||
}
|
127
vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go
generated
vendored
Normal file
127
vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go
generated
vendored
Normal file
|
@ -0,0 +1,127 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
rttAlpha = 0.125
|
||||
oneMinusAlpha = 1 - rttAlpha
|
||||
rttBeta = 0.25
|
||||
oneMinusBeta = 1 - rttBeta
|
||||
// The default RTT used before an RTT sample is taken.
|
||||
defaultInitialRTT = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
// RTTStats provides round-trip statistics
|
||||
type RTTStats struct {
|
||||
hasMeasurement bool
|
||||
|
||||
minRTT time.Duration
|
||||
latestRTT time.Duration
|
||||
smoothedRTT time.Duration
|
||||
meanDeviation time.Duration
|
||||
|
||||
maxAckDelay time.Duration
|
||||
}
|
||||
|
||||
// NewRTTStats makes a properly initialized RTTStats object
|
||||
func NewRTTStats() *RTTStats {
|
||||
return &RTTStats{}
|
||||
}
|
||||
|
||||
// MinRTT Returns the minRTT for the entire connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
|
||||
|
||||
// LatestRTT returns the most recent rtt measurement.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
|
||||
|
||||
// SmoothedRTT returns the smoothed RTT for the connection.
|
||||
// May return Zero if no valid updates have occurred.
|
||||
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
|
||||
|
||||
// MeanDeviation gets the mean deviation
|
||||
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
|
||||
|
||||
// MaxAckDelay gets the max_ack_delay advertised by the peer
|
||||
func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay }
|
||||
|
||||
// PTO gets the probe timeout duration.
|
||||
func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
|
||||
if r.SmoothedRTT() == 0 {
|
||||
return 2 * defaultInitialRTT
|
||||
}
|
||||
pto := r.SmoothedRTT() + Max(4*r.MeanDeviation(), protocol.TimerGranularity)
|
||||
if includeMaxAckDelay {
|
||||
pto += r.MaxAckDelay()
|
||||
}
|
||||
return pto
|
||||
}
|
||||
|
||||
// UpdateRTT updates the RTT based on a new sample.
|
||||
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
||||
if sendDelta == InfDuration || sendDelta <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
|
||||
// ackDelay but the raw observed sendDelta, since poor clock granularity at
|
||||
// the client may cause a high ackDelay to result in underestimation of the
|
||||
// r.minRTT.
|
||||
if r.minRTT == 0 || r.minRTT > sendDelta {
|
||||
r.minRTT = sendDelta
|
||||
}
|
||||
|
||||
// Correct for ackDelay if information received from the peer results in a
|
||||
// an RTT sample at least as large as minRTT. Otherwise, only use the
|
||||
// sendDelta.
|
||||
sample := sendDelta
|
||||
if sample-r.minRTT >= ackDelay {
|
||||
sample -= ackDelay
|
||||
}
|
||||
r.latestRTT = sample
|
||||
// First time call.
|
||||
if !r.hasMeasurement {
|
||||
r.hasMeasurement = true
|
||||
r.smoothedRTT = sample
|
||||
r.meanDeviation = sample / 2
|
||||
} else {
|
||||
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
|
||||
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAckDelay sets the max_ack_delay
|
||||
func (r *RTTStats) SetMaxAckDelay(mad time.Duration) {
|
||||
r.maxAckDelay = mad
|
||||
}
|
||||
|
||||
// SetInitialRTT sets the initial RTT.
|
||||
// It is used during the 0-RTT handshake when restoring the RTT stats from the session state.
|
||||
func (r *RTTStats) SetInitialRTT(t time.Duration) {
|
||||
if r.hasMeasurement {
|
||||
panic("initial RTT set after first measurement")
|
||||
}
|
||||
r.smoothedRTT = t
|
||||
r.latestRTT = t
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
|
||||
func (r *RTTStats) OnConnectionMigration() {
|
||||
r.latestRTT = 0
|
||||
r.minRTT = 0
|
||||
r.smoothedRTT = 0
|
||||
r.meanDeviation = 0
|
||||
}
|
||||
|
||||
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
|
||||
// is larger. The mean deviation is increased to the most recent deviation if
|
||||
// it's larger.
|
||||
func (r *RTTStats) ExpireSmoothedMetrics() {
|
||||
r.meanDeviation = Max(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT))
|
||||
r.smoothedRTT = Max(r.smoothedRTT, r.latestRTT)
|
||||
}
|
57
vendor/github.com/quic-go/quic-go/internal/utils/timer.go
generated
vendored
Normal file
57
vendor/github.com/quic-go/quic-go/internal/utils/timer.go
generated
vendored
Normal file
|
@ -0,0 +1,57 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Timer wrapper that behaves correctly when resetting
|
||||
type Timer struct {
|
||||
t *time.Timer
|
||||
read bool
|
||||
deadline time.Time
|
||||
}
|
||||
|
||||
// NewTimer creates a new timer that is not set
|
||||
func NewTimer() *Timer {
|
||||
return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))}
|
||||
}
|
||||
|
||||
// Chan returns the channel of the wrapped timer
|
||||
func (t *Timer) Chan() <-chan time.Time {
|
||||
return t.t.C
|
||||
}
|
||||
|
||||
// Reset the timer, no matter whether the value was read or not
|
||||
func (t *Timer) Reset(deadline time.Time) {
|
||||
if deadline.Equal(t.deadline) && !t.read {
|
||||
// No need to reset the timer
|
||||
return
|
||||
}
|
||||
|
||||
// We need to drain the timer if the value from its channel was not read yet.
|
||||
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
|
||||
if !t.t.Stop() && !t.read {
|
||||
<-t.t.C
|
||||
}
|
||||
if !deadline.IsZero() {
|
||||
t.t.Reset(time.Until(deadline))
|
||||
}
|
||||
|
||||
t.read = false
|
||||
t.deadline = deadline
|
||||
}
|
||||
|
||||
// SetRead should be called after the value from the chan was read
|
||||
func (t *Timer) SetRead() {
|
||||
t.read = true
|
||||
}
|
||||
|
||||
func (t *Timer) Deadline() time.Time {
|
||||
return t.deadline
|
||||
}
|
||||
|
||||
// Stop stops the timer
|
||||
func (t *Timer) Stop() {
|
||||
t.t.Stop()
|
||||
}
|
251
vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go
generated
vendored
Normal file
251
vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go
generated
vendored
Normal file
|
@ -0,0 +1,251 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
|
||||
|
||||
// An AckFrame is an ACK frame
|
||||
type AckFrame struct {
|
||||
AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
|
||||
DelayTime time.Duration
|
||||
|
||||
ECT0, ECT1, ECNCE uint64
|
||||
}
|
||||
|
||||
// parseAckFrame reads an ACK frame
|
||||
func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) {
|
||||
typeByte, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ecn := typeByte&0x1 > 0
|
||||
|
||||
frame := GetAckFrame()
|
||||
|
||||
la, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
largestAcked := protocol.PacketNumber(la)
|
||||
delay, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
|
||||
if delayTime < 0 {
|
||||
// If the delay time overflows, set it to the maximum encodable value.
|
||||
delayTime = utils.InfDuration
|
||||
}
|
||||
frame.DelayTime = delayTime
|
||||
|
||||
numBlocks, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// read the first ACK range
|
||||
ab, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ackBlock := protocol.PacketNumber(ab)
|
||||
if ackBlock > largestAcked {
|
||||
return nil, errors.New("invalid first ACK range")
|
||||
}
|
||||
smallest := largestAcked - ackBlock
|
||||
|
||||
// read all the other ACK ranges
|
||||
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
|
||||
for i := uint64(0); i < numBlocks; i++ {
|
||||
g, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gap := protocol.PacketNumber(g)
|
||||
if smallest < gap+2 {
|
||||
return nil, errInvalidAckRanges
|
||||
}
|
||||
largest := smallest - gap - 2
|
||||
|
||||
ab, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ackBlock := protocol.PacketNumber(ab)
|
||||
|
||||
if ackBlock > largest {
|
||||
return nil, errInvalidAckRanges
|
||||
}
|
||||
smallest = largest - ackBlock
|
||||
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
|
||||
}
|
||||
|
||||
if !frame.validateAckRanges() {
|
||||
return nil, errInvalidAckRanges
|
||||
}
|
||||
|
||||
// parse (and skip) the ECN section
|
||||
if ecn {
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := quicvarint.Read(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
// Append appends an ACK frame.
|
||||
func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) {
|
||||
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
|
||||
if hasECN {
|
||||
b = append(b, 0b11)
|
||||
} else {
|
||||
b = append(b, 0b10)
|
||||
}
|
||||
b = quicvarint.Append(b, uint64(f.LargestAcked()))
|
||||
b = quicvarint.Append(b, encodeAckDelay(f.DelayTime))
|
||||
|
||||
numRanges := f.numEncodableAckRanges()
|
||||
b = quicvarint.Append(b, uint64(numRanges-1))
|
||||
|
||||
// write the first range
|
||||
_, firstRange := f.encodeAckRange(0)
|
||||
b = quicvarint.Append(b, firstRange)
|
||||
|
||||
// write all the other range
|
||||
for i := 1; i < numRanges; i++ {
|
||||
gap, len := f.encodeAckRange(i)
|
||||
b = quicvarint.Append(b, gap)
|
||||
b = quicvarint.Append(b, len)
|
||||
}
|
||||
|
||||
if hasECN {
|
||||
b = quicvarint.Append(b, f.ECT0)
|
||||
b = quicvarint.Append(b, f.ECT1)
|
||||
b = quicvarint.Append(b, f.ECNCE)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
|
||||
largestAcked := f.AckRanges[0].Largest
|
||||
numRanges := f.numEncodableAckRanges()
|
||||
|
||||
length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime))
|
||||
|
||||
length += quicvarint.Len(uint64(numRanges - 1))
|
||||
lowestInFirstRange := f.AckRanges[0].Smallest
|
||||
length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange))
|
||||
|
||||
for i := 1; i < numRanges; i++ {
|
||||
gap, len := f.encodeAckRange(i)
|
||||
length += quicvarint.Len(gap)
|
||||
length += quicvarint.Len(len)
|
||||
}
|
||||
if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 {
|
||||
length += quicvarint.Len(f.ECT0)
|
||||
length += quicvarint.Len(f.ECT1)
|
||||
length += quicvarint.Len(f.ECNCE)
|
||||
}
|
||||
return length
|
||||
}
|
||||
|
||||
// gets the number of ACK ranges that can be encoded
|
||||
// such that the resulting frame is smaller than the maximum ACK frame size
|
||||
func (f *AckFrame) numEncodableAckRanges() int {
|
||||
length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime))
|
||||
length += 2 // assume that the number of ranges will consume 2 bytes
|
||||
for i := 1; i < len(f.AckRanges); i++ {
|
||||
gap, len := f.encodeAckRange(i)
|
||||
rangeLen := quicvarint.Len(gap) + quicvarint.Len(len)
|
||||
if length+rangeLen > protocol.MaxAckFrameSize {
|
||||
// Writing range i would exceed the MaxAckFrameSize.
|
||||
// So encode one range less than that.
|
||||
return i - 1
|
||||
}
|
||||
length += rangeLen
|
||||
}
|
||||
return len(f.AckRanges)
|
||||
}
|
||||
|
||||
func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) {
|
||||
if i == 0 {
|
||||
return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest)
|
||||
}
|
||||
return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2),
|
||||
uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest)
|
||||
}
|
||||
|
||||
// HasMissingRanges returns if this frame reports any missing packets
|
||||
func (f *AckFrame) HasMissingRanges() bool {
|
||||
return len(f.AckRanges) > 1
|
||||
}
|
||||
|
||||
func (f *AckFrame) validateAckRanges() bool {
|
||||
if len(f.AckRanges) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// check the validity of every single ACK range
|
||||
for _, ackRange := range f.AckRanges {
|
||||
if ackRange.Smallest > ackRange.Largest {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// check the consistency for ACK with multiple NACK ranges
|
||||
for i, ackRange := range f.AckRanges {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
lastAckRange := f.AckRanges[i-1]
|
||||
if lastAckRange.Smallest <= ackRange.Smallest {
|
||||
return false
|
||||
}
|
||||
if lastAckRange.Smallest <= ackRange.Largest+1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// LargestAcked is the largest acked packet number
|
||||
func (f *AckFrame) LargestAcked() protocol.PacketNumber {
|
||||
return f.AckRanges[0].Largest
|
||||
}
|
||||
|
||||
// LowestAcked is the lowest acked packet number
|
||||
func (f *AckFrame) LowestAcked() protocol.PacketNumber {
|
||||
return f.AckRanges[len(f.AckRanges)-1].Smallest
|
||||
}
|
||||
|
||||
// AcksPacket determines if this ACK frame acks a certain packet number
|
||||
func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
|
||||
if p < f.LowestAcked() || p > f.LargestAcked() {
|
||||
return false
|
||||
}
|
||||
|
||||
i := sort.Search(len(f.AckRanges), func(i int) bool {
|
||||
return p >= f.AckRanges[i].Smallest
|
||||
})
|
||||
// i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked
|
||||
return p <= f.AckRanges[i].Largest
|
||||
}
|
||||
|
||||
func encodeAckDelay(delay time.Duration) uint64 {
|
||||
return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue