Update deps

This commit is contained in:
Frank Denis 2024-06-13 23:44:17 +02:00
parent 7a4b2ac7ea
commit 2dd6c8e996
104 changed files with 5055 additions and 1906 deletions

View file

@ -188,6 +188,9 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
* 8777 - DNS Reverse IP Automatic Multicast Tunneling (AMT) Discovery
* 8914 - Extended DNS Errors
* 8976 - Message Digest for DNS Zones (ZONEMD RR)
* 9460 - Service Binding and Parameter Specification via the DNS
* 9461 - Service Binding Mapping for DNS Servers
* 9462 - Discovery of Designated Resolvers
## Loosely Based Upon

View file

@ -55,7 +55,10 @@ func endingToTxtSlice(c *zlexer, errstr string) ([]string, *ParseError) {
sx := []string{}
p := 0
for {
i := escapedStringOffset(l.token[p:], 255)
i, ok := escapedStringOffset(l.token[p:], 255)
if !ok {
return nil, &ParseError{err: errstr, lex: l}
}
if i != -1 && p+i != len(l.token) {
sx = append(sx, l.token[p:p+i])
} else {
@ -1919,29 +1922,36 @@ func (rr *APL) parse(c *zlexer, o string) *ParseError {
// escapedStringOffset finds the offset within a string (which may contain escape
// sequences) that corresponds to a certain byte offset. If the input offset is
// out of bounds, -1 is returned.
func escapedStringOffset(s string, byteOffset int) int {
if byteOffset == 0 {
return 0
// out of bounds, -1 is returned (which is *not* considered an error).
func escapedStringOffset(s string, desiredByteOffset int) (int, bool) {
if desiredByteOffset == 0 {
return 0, true
}
offset := 0
for i := 0; i < len(s); i++ {
offset += 1
currentByteOffset, i := 0, 0
for i < len(s) {
currentByteOffset += 1
// Skip escape sequences
if s[i] != '\\' {
// Not an escape sequence; nothing to do.
} else if isDDD(s[i+1:]) {
i += 3
} else {
// Single plain byte, not an escape sequence.
i++
} else if isDDD(s[i+1:]) {
// Skip backslash and DDD.
i += 4
} else if len(s[i+1:]) < 1 {
// No character following the backslash; that's an error.
return 0, false
} else {
// Skip backslash and following byte.
i += 2
}
if offset >= byteOffset {
return i + 1
if currentByteOffset >= desiredByteOffset {
return i, true
}
}
return -1
return -1, true
}

View file

@ -188,6 +188,14 @@ type DecorateReader func(Reader) Reader
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer
// MsgInvalidFunc is a listener hook for observing incoming messages that were discarded
// because they could not be parsed.
// Every message that is read by a Reader will eventually be provided to the Handler,
// rejected (or ignored) by the MsgAcceptFunc, or passed to this function.
type MsgInvalidFunc func(m []byte, err error)
func DefaultMsgInvalidFunc(m []byte, err error) {}
// A Server defines parameters for running an DNS server.
type Server struct {
// Address to listen on, ":dns" if empty.
@ -233,6 +241,8 @@ type Server struct {
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
// By default DefaultMsgAcceptFunc will be used.
MsgAcceptFunc MsgAcceptFunc
// MsgInvalidFunc is optional, will be called if a message is received but cannot be parsed.
MsgInvalidFunc MsgInvalidFunc
// Shutdown handling
lock sync.RWMutex
@ -277,6 +287,9 @@ func (srv *Server) init() {
if srv.MsgAcceptFunc == nil {
srv.MsgAcceptFunc = DefaultMsgAcceptFunc
}
if srv.MsgInvalidFunc == nil {
srv.MsgInvalidFunc = DefaultMsgInvalidFunc
}
if srv.Handler == nil {
srv.Handler = DefaultServeMux
}
@ -531,6 +544,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error {
if cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
srv.MsgInvalidFunc(m, ErrShortRead)
continue
}
wg.Add(1)
@ -611,6 +625,7 @@ func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn
func (srv *Server) serveDNS(m []byte, w *response) {
dh, off, err := unpackMsgHdr(m, 0)
if err != nil {
srv.MsgInvalidFunc(m, err)
// Let client hang, they are sending crap; any reply can be used to amplify.
return
}
@ -620,10 +635,12 @@ func (srv *Server) serveDNS(m []byte, w *response) {
switch action := srv.MsgAcceptFunc(dh); action {
case MsgAccept:
if req.unpack(dh, m, off) == nil {
err := req.unpack(dh, m, off)
if err == nil {
break
}
srv.MsgInvalidFunc(m, err)
fallthrough
case MsgReject, MsgRejectNotImplemented:
opcode := req.Opcode

50
vendor/github.com/miekg/dns/svcb.go generated vendored
View file

@ -14,7 +14,7 @@ import (
// SVCBKey is the type of the keys used in the SVCB RR.
type SVCBKey uint16
// Keys defined in draft-ietf-dnsop-svcb-https-08 Section 14.3.2.
// Keys defined in rfc9460
const (
SVCB_MANDATORY SVCBKey = iota
SVCB_ALPN
@ -23,7 +23,8 @@ const (
SVCB_IPV4HINT
SVCB_ECHCONFIG
SVCB_IPV6HINT
SVCB_DOHPATH // draft-ietf-add-svcb-dns-02 Section 9
SVCB_DOHPATH // rfc9461 Section 5
SVCB_OHTTP // rfc9540 Section 8
svcb_RESERVED SVCBKey = 65535
)
@ -37,6 +38,7 @@ var svcbKeyToStringMap = map[SVCBKey]string{
SVCB_ECHCONFIG: "ech",
SVCB_IPV6HINT: "ipv6hint",
SVCB_DOHPATH: "dohpath",
SVCB_OHTTP: "ohttp",
}
var svcbStringToKeyMap = reverseSVCBKeyMap(svcbKeyToStringMap)
@ -201,6 +203,8 @@ func makeSVCBKeyValue(key SVCBKey) SVCBKeyValue {
return new(SVCBIPv6Hint)
case SVCB_DOHPATH:
return new(SVCBDoHPath)
case SVCB_OHTTP:
return new(SVCBOhttp)
case svcb_RESERVED:
return nil
default:
@ -771,8 +775,8 @@ func (s *SVCBIPv6Hint) copy() SVCBKeyValue {
// SVCBDoHPath pair is used to indicate the URI template that the
// clients may use to construct a DNS over HTTPS URI.
//
// See RFC xxxx (https://datatracker.ietf.org/doc/html/draft-ietf-add-svcb-dns-02)
// and RFC yyyy (https://datatracker.ietf.org/doc/html/draft-ietf-add-ddr-06).
// See RFC 9461 (https://datatracker.ietf.org/doc/html/rfc9461)
// and RFC 9462 (https://datatracker.ietf.org/doc/html/rfc9462).
//
// A basic example of using the dohpath option together with the alpn
// option to indicate support for DNS over HTTPS on a certain path:
@ -816,6 +820,44 @@ func (s *SVCBDoHPath) copy() SVCBKeyValue {
}
}
// The "ohttp" SvcParamKey is used to indicate that a service described in a SVCB RR
// can be accessed as a target using an associated gateway.
// Both the presentation and wire-format values for the "ohttp" parameter MUST be empty.
//
// See RFC 9460 (https://datatracker.ietf.org/doc/html/rfc9460/)
// and RFC 9230 (https://datatracker.ietf.org/doc/html/rfc9230/)
//
// A basic example of using the dohpath option together with the alpn
// option to indicate support for DNS over HTTPS on a certain path:
//
// s := new(dns.SVCB)
// s.Hdr = dns.RR_Header{Name: ".", Rrtype: dns.TypeSVCB, Class: dns.ClassINET}
// e := new(dns.SVCBAlpn)
// e.Alpn = []string{"h2", "h3"}
// p := new(dns.SVCBOhttp)
// s.Value = append(s.Value, e, p)
type SVCBOhttp struct{}
func (*SVCBOhttp) Key() SVCBKey { return SVCB_OHTTP }
func (*SVCBOhttp) copy() SVCBKeyValue { return &SVCBOhttp{} }
func (*SVCBOhttp) pack() ([]byte, error) { return []byte{}, nil }
func (*SVCBOhttp) String() string { return "" }
func (*SVCBOhttp) len() int { return 0 }
func (*SVCBOhttp) unpack(b []byte) error {
if len(b) != 0 {
return errors.New("dns: svcbotthp: svcbotthp must have no value")
}
return nil
}
func (*SVCBOhttp) parse(b string) error {
if b != "" {
return errors.New("dns: svcbotthp: svcbotthp must have no value")
}
return nil
}
// SVCBLocal pair is intended for experimental/private use. The key is recommended
// to be in the range [SVCB_PRIVATE_LOWER, SVCB_PRIVATE_UPPER].
// Basic use pattern for creating a keyNNNNN option:

1
vendor/github.com/miekg/dns/xfr.go generated vendored
View file

@ -209,6 +209,7 @@ func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
// ch := make(chan *dns.Envelope)
// tr := new(dns.Transfer)
// var wg sync.WaitGroup
// wg.Add(1)
// go func() {
// tr.Out(w, r, ch)
// wg.Done()

View file

@ -191,6 +191,7 @@ func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID),
c.sendConn,
c.packetHandlers,
c.destConnID,
@ -202,7 +203,6 @@ func (c *client) dial(ctx context.Context) error {
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
)

View file

@ -52,7 +52,7 @@ type streamManager interface {
}
type cryptoStreamHandler interface {
StartHandshake() error
StartHandshake(context.Context) error
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
@ -169,10 +169,9 @@ type connection struct {
// closeChan is used to notify the run loop that it should terminate
closeChan chan closeError
ctx context.Context
ctxCancel context.CancelCauseFunc
handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelCauseFunc
handshakeCompleteChan chan struct{}
undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []receivedPacket
@ -222,6 +221,8 @@ var (
)
var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
origDestConnID protocol.ConnectionID,
@ -236,11 +237,12 @@ var newConnection = func(
tokenGenerator *handshake.TokenGenerator,
clientAddressValidated bool,
tracer *logging.ConnectionTracer,
tracingID ConnectionTracingID,
logger utils.Logger,
v protocol.Version,
) quicConn {
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
@ -274,7 +276,6 @@ var newConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
@ -339,6 +340,7 @@ var newConnection = func(
// declare this as a variable, such that we can it mock it in the tests
var newClientConnection = func(
ctx context.Context,
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
@ -350,7 +352,6 @@ var newClientConnection = func(
enable0RTT bool,
hasNegotiatedVersion bool,
tracer *logging.ConnectionTracer,
tracingID ConnectionTracingID,
logger utils.Logger,
v protocol.Version,
) quicConn {
@ -384,7 +385,7 @@ var newClientConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
@ -486,7 +487,7 @@ func (s *connection) preSetup() {
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background())
s.handshakeCompleteChan = make(chan struct{})
now := time.Now()
s.lastPacketReceivedTime = now
@ -500,13 +501,11 @@ func (s *connection) preSetup() {
// run the connection main loop
func (s *connection) run() error {
var closeErr closeError
defer func() {
s.ctxCancel(closeErr.err)
}()
defer func() { s.ctxCancel(closeErr.err) }()
s.timer = *newTimer()
if err := s.cryptoStreamHandler.StartHandshake(); err != nil {
if err := s.cryptoStreamHandler.StartHandshake(s.ctx); err != nil {
return err
}
if err := s.handleHandshakeEvents(); err != nil {
@ -667,7 +666,7 @@ func (s *connection) earlyConnReady() <-chan struct{} {
}
func (s *connection) HandshakeComplete() <-chan struct{} {
return s.handshakeCtx.Done()
return s.handshakeCompleteChan
}
func (s *connection) Context() context.Context {
@ -732,7 +731,7 @@ func (s *connection) idleTimeoutStartTime() time.Time {
}
func (s *connection) handleHandshakeComplete() error {
defer s.handshakeCtxCancel()
defer close(s.handshakeCompleteChan)
// Once the handshake completes, we have derived 1-RTT keys.
// There's no point in queueing undecryptable packets for later decryption anymore.
s.undecryptablePackets = nil
@ -2425,10 +2424,17 @@ func (s *connection) GetVersion() protocol.Version {
return s.version
}
func (s *connection) NextConnection() Connection {
<-s.HandshakeComplete()
s.streamsMap.UseResetMaps()
return s
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
// The handshake might fail after the server rejected 0-RTT.
// This could happen if the Finished message is malformed or never received.
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-s.Context().Done():
case <-s.HandshakeComplete():
s.streamsMap.UseResetMaps()
}
return s, nil
}
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.

View file

@ -82,7 +82,13 @@ func (c *SingleDestinationRoundTripper) Start() Connection {
func (c *SingleDestinationRoundTripper) init() {
c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
c.requestWriter = newRequestWriter()
c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger)
c.hconn = newConnection(
c.Connection.Context(),
c.Connection,
c.EnableDatagrams,
protocol.PerspectiveClient,
c.Logger,
)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(c.hconn); err != nil {

View file

@ -37,6 +37,7 @@ type Connection interface {
type connection struct {
quic.Connection
ctx context.Context
perspective protocol.Perspective
logger *slog.Logger
@ -53,12 +54,14 @@ type connection struct {
}
func newConnection(
ctx context.Context,
quicConn quic.Connection,
enableDatagrams bool,
perspective protocol.Perspective,
logger *slog.Logger,
) *connection {
c := &connection{
return &connection{
ctx: ctx,
Connection: quicConn,
perspective: perspective,
logger: logger,
@ -67,7 +70,6 @@ func newConnection(
receivedSettings: make(chan struct{}),
streams: make(map[protocol.StreamID]*datagrammer),
}
return c
}
func (c *connection) clearStream(id quic.StreamID) {
@ -264,3 +266,5 @@ func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSetti
// Settings returns the settings received on this connection.
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
func (c *connection) Settings() *Settings { return c.settings }
func (c *connection) Context() context.Context { return c.ctx }

View file

@ -94,6 +94,10 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
if config == nil {
return nil, nil
}
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = config.DecryptTicket(nil, tls.ConnectionState{})
config = config.Clone()
config.NextProtos = []string{proto}
return config, nil
@ -194,9 +198,8 @@ type Server struct {
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
// ConnContext optionally specifies a function that modifies
// the context used for a new connection c. The provided ctx
// has a ServerContextKey value.
// ConnContext optionally specifies a function that modifies the context used for a new connection c.
// The provided ctx has a ServerContextKey value.
ConnContext func(ctx context.Context, c quic.Connection) context.Context
Logger *slog.Logger
@ -436,7 +439,19 @@ func (s *Server) handleConn(conn quic.Connection) error {
}).Append(b)
str.Write(b)
ctx := conn.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}
hconn := newConnection(
ctx,
conn,
s.EnableDatagrams,
protocol.PerspectiveServer,
@ -533,17 +548,10 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
s.Logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI)
}
ctx := str.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn.Connection)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}
ctx, cancel := context.WithCancel(conn.Context())
req = req.WithContext(ctx)
context.AfterFunc(str.Context(), cancel)
r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.Logger)
handler := s.Handler
if handler == nil {

View file

@ -57,9 +57,11 @@ var Err0RTTRejected = errors.New("0-RTT rejected")
// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
// It is set on the Connection.Context() context,
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
var ConnectionTracingKey = connTracingCtxKey{}
// ConnectionTracingID is the type of the context value saved under the ConnectionTracingKey.
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
type ConnectionTracingID uint64
type connTracingCtxKey struct{}
@ -222,7 +224,7 @@ type EarlyConnection interface {
// however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{}
NextConnection() Connection
NextConnection(context.Context) (Connection, error)
}
// StatelessResetKey is a key used to derive stateless reset tokens.
@ -334,7 +336,6 @@ type Config struct {
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1280 bytes in size.
DisablePathMTUDiscovery bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// Only valid for the server.

View file

@ -123,44 +123,12 @@ func NewCryptoSetupServer(
)
cs.allow0RTT = allow0RTT
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
cs.tlsConf = quicConf.TLSConfig
cs.conn = tls.QUICServer(quicConf)
tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
return cs
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
c = c.Clone()
// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
c.MinVersion = tls.VersionTLS13
// We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
}
func newCryptoSetup(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
@ -203,8 +171,8 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) StartHandshake() error {
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
err := h.conn.Start(context.WithValue(ctx, QUICVersionContextKey, h.version))
if err != nil {
return wrapError(err)
}
@ -376,9 +344,7 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
}); err != nil {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.

View file

@ -1,6 +1,7 @@
package handshake
import (
"context"
"crypto/tls"
"errors"
"io"
@ -91,7 +92,7 @@ type Event struct {
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
StartHandshake() error
StartHandshake(context.Context) error
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)

View file

@ -1,4 +1,4 @@
package handshake
package qtls
import (
"net"

View file

@ -4,20 +4,23 @@ import (
"bytes"
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig
func SetupConfigForServer(
conf *tls.Config,
localAddr, remoteAddr net.Addr,
getData func() []byte,
handleSessionTicket func([]byte, bool) bool,
) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
qconf.TLSConfig = conf
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
@ -58,6 +61,29 @@ func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte,
return state, nil
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}
func SetupConfigForClient(

View file

@ -25,16 +25,80 @@ const (
maxMTUDiff = 20
// send a probe packet every mtuProbeDelay RTTs
mtuProbeDelay = 5
// Once maxLostMTUProbes MTU probe packets larger than a certain size are lost,
// MTU discovery won't probe for larger MTUs than this size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
maxLostMTUProbes = 3
)
// The Path MTU is found by sending a larger packet every now and then.
// If the packet is acknowledged, we conclude that the path supports this larger packet size.
// If the packet is lost, this can mean one of two things:
// 1. The path doesn't support this larger packet size, or
// 2. The packet was lost due to packet loss, independent of its size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
// For simplicty, the following example use maxLostMTUProbes = 2.
//
// Initialization:
// |------------------------------------------------------------------------------|
// min max
//
// The first MTU probe packet will have size (min+max)/2.
// Assume that this packet is acknowledged. We can now move the min marker,
// and continue the search in the resulting interval.
//
// If 1st probe packet acknowledged:
// |---------------------------------------|--------------------------------------|
// min max
//
// If 1st probe packet lost:
// |---------------------------------------|--------------------------------------|
// min lost[0] max
//
// We can't conclude that the path doesn't support this packet size, since the loss of the probe
// packet could have been unrelated to the packet size. A larger probe packet will be sent later on.
// After a loss, the next probe packet has size (min+lost[0])/2.
// Now assume this probe packet is acknowledged:
//
// 2nd probe packet acknowledged:
// |------------------|--------------------|--------------------------------------|
// min lost[0] max
//
// First of all, we conclude that the path supports at least this MTU. That's progress!
// Second, we probe a bit more aggressively with the next probe packet:
// After an acknowledgement, the next probe packet has size (min+max)/2.
// This means we'll send a packet larger than the first probe packet (which was lost).
//
// If 3rd probe packet acknowledged:
// |-------------------------------------------------|----------------------------|
// min max
//
// We can conclude that the loss of the 1st probe packet was not due to its size, and
// continue searching in a much smaller interval now.
//
// If 3rd probe packet lost:
// |------------------|--------------------|---------|----------------------------|
// min lost[0] max
//
// Since in our example numPTOProbes = 2, and we lost 2 packets smaller than max, we
// conclude that this packet size is not supported on the path, and reduce the maximum
// value of the search interval.
//
// MTU discovery concludes once the interval min and max has been narrowed down to maxMTUDiff.
type mtuFinder struct {
lastProbeTime time.Time
mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
current protocol.ByteCount
max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer)
min protocol.ByteCount
limit protocol.ByteCount
// on initialization, we treat the maximum size as the first "lost" packet
lost [maxLostMTUProbes]protocol.ByteCount
lastProbeWasLost bool
tracer *logging.ConnectionTracer
}
@ -47,33 +111,43 @@ func newMTUDiscoverer(
mtuIncreased func(protocol.ByteCount),
tracer *logging.ConnectionTracer,
) *mtuFinder {
return &mtuFinder{
f := &mtuFinder{
inFlight: protocol.InvalidByteCount,
current: start,
max: max,
min: start,
limit: max,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
tracer: tracer,
}
for i := range f.lost {
if i == 0 {
f.lost[i] = max
continue
}
f.lost[i] = protocol.InvalidByteCount
}
return f
}
func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1
return f.max()-f.min <= maxMTUDiff+1
}
func (f *mtuFinder) SetMax(max protocol.ByteCount) {
f.max = max
func (f *mtuFinder) max() protocol.ByteCount {
for i, v := range f.lost {
if v == protocol.InvalidByteCount {
return f.lost[i-1]
}
}
return f.lost[len(f.lost)-1]
}
func (f *mtuFinder) Start() {
if f.max == protocol.InvalidByteCount {
panic("invalid")
}
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
if f.max == 0 || f.lastProbeTime.IsZero() {
if f.lastProbeTime.IsZero() {
return false
}
if f.inFlight != protocol.InvalidByteCount || f.done() {
@ -83,7 +157,12 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
}
func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
size := (f.max + f.current) / 2
var size protocol.ByteCount
if f.lastProbeWasLost {
size = (f.min + f.lost[0]) / 2
} else {
size = (f.min + f.max()) / 2
}
f.lastProbeTime = time.Now()
f.inFlight = size
return ackhandler.Frame{
@ -93,7 +172,7 @@ func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
}
func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.current
return f.min
}
type mtuFinderAckHandler struct {
@ -108,7 +187,25 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
panic("OnAcked callback called although there's no MTU probe packet in flight")
}
h.inFlight = protocol.InvalidByteCount
h.current = size
h.min = size
h.lastProbeWasLost = false
// remove all values smaller than size from the lost array
var j int
for i, v := range h.lost {
if size < v {
j = i
break
}
}
if j > 0 {
for i := 0; i < len(h.lost); i++ {
if i+j < len(h.lost) {
h.lost[i] = h.lost[i+j]
} else {
h.lost[i] = protocol.InvalidByteCount
}
}
}
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
@ -120,6 +217,13 @@ func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
if size == protocol.InvalidByteCount {
panic("OnLost callback called although there's no MTU probe packet in flight")
}
h.max = size
h.lastProbeWasLost = true
h.inFlight = protocol.InvalidByteCount
for i, v := range h.lost {
if size < v {
copy(h.lost[i+1:], h.lost[i:])
h.lost[i] = size
break
}
}
}

View file

@ -478,7 +478,6 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.ctxCancel(err)
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalWrite()

View file

@ -76,8 +76,12 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func(context.Context) context.Context
// set as a member, so they can be set in the tests
newConn func(
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
@ -92,7 +96,6 @@ type baseServer struct {
*handshake.TokenGenerator,
bool, /* client address validated by an address validation token */
*logging.ConnectionTracer,
ConnectionTracingID,
utils.Logger,
protocol.Version,
) quicConn
@ -231,6 +234,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
@ -243,6 +247,7 @@ func newServer(
) *baseServer {
s := &baseServer{
conn: conn,
connContext: connContext,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
@ -631,7 +636,26 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
var conn quicConn
tracingID := nextConnTracingID()
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext(ctx)
if ctx == nil {
panic("quic: ConnContext returned nil")
}
// There's no guarantee that the application returns a context
// that's derived from the context we passed into ConnContext.
// We need to make sure that both contexts are cancelled.
var cancel2 context.CancelCauseFunc
ctx, cancel2 = context.WithCancelCause(ctx)
cancel = func(cause error) {
cancel1(cause)
cancel2(cause)
}
} else {
cancel = cancel1
}
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
@ -639,7 +663,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
tracer = config.Tracer(ctx, protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
@ -647,6 +671,8 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
@ -661,7 +687,6 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.tokenGenerator,
clientAddrVerified,
tracer,
tracingID,
s.logger,
hdr.Version,
)

View file

@ -89,6 +89,17 @@ type Transport struct {
// implementation of this callback (negating its return value).
VerifySourceAddress func(net.Addr) bool
// ConnContext is called when the server accepts a new connection.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
// * the context used in Config.Tracer
// * the context returned from Connection.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func(context.Context) context.Context
// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Tracer *logging.Tracer
@ -168,6 +179,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
t.conn,
t.handlerMap,
t.connIDGenerator,
t.ConnContext,
tlsConf,
conf,
t.Tracer,