Update deps

This commit is contained in:
Frank Denis 2024-06-13 23:44:17 +02:00
parent 7a4b2ac7ea
commit 2dd6c8e996
104 changed files with 5055 additions and 1906 deletions

View file

@ -191,6 +191,7 @@ func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID),
c.sendConn,
c.packetHandlers,
c.destConnID,
@ -202,7 +203,6 @@ func (c *client) dial(ctx context.Context) error {
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
)

View file

@ -52,7 +52,7 @@ type streamManager interface {
}
type cryptoStreamHandler interface {
StartHandshake() error
StartHandshake(context.Context) error
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
@ -169,10 +169,9 @@ type connection struct {
// closeChan is used to notify the run loop that it should terminate
closeChan chan closeError
ctx context.Context
ctxCancel context.CancelCauseFunc
handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelCauseFunc
handshakeCompleteChan chan struct{}
undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []receivedPacket
@ -222,6 +221,8 @@ var (
)
var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
origDestConnID protocol.ConnectionID,
@ -236,11 +237,12 @@ var newConnection = func(
tokenGenerator *handshake.TokenGenerator,
clientAddressValidated bool,
tracer *logging.ConnectionTracer,
tracingID ConnectionTracingID,
logger utils.Logger,
v protocol.Version,
) quicConn {
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
@ -274,7 +276,6 @@ var newConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
@ -339,6 +340,7 @@ var newConnection = func(
// declare this as a variable, such that we can it mock it in the tests
var newClientConnection = func(
ctx context.Context,
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
@ -350,7 +352,6 @@ var newClientConnection = func(
enable0RTT bool,
hasNegotiatedVersion bool,
tracer *logging.ConnectionTracer,
tracingID ConnectionTracingID,
logger utils.Logger,
v protocol.Version,
) quicConn {
@ -384,7 +385,7 @@ var newClientConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
@ -486,7 +487,7 @@ func (s *connection) preSetup() {
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background())
s.handshakeCompleteChan = make(chan struct{})
now := time.Now()
s.lastPacketReceivedTime = now
@ -500,13 +501,11 @@ func (s *connection) preSetup() {
// run the connection main loop
func (s *connection) run() error {
var closeErr closeError
defer func() {
s.ctxCancel(closeErr.err)
}()
defer func() { s.ctxCancel(closeErr.err) }()
s.timer = *newTimer()
if err := s.cryptoStreamHandler.StartHandshake(); err != nil {
if err := s.cryptoStreamHandler.StartHandshake(s.ctx); err != nil {
return err
}
if err := s.handleHandshakeEvents(); err != nil {
@ -667,7 +666,7 @@ func (s *connection) earlyConnReady() <-chan struct{} {
}
func (s *connection) HandshakeComplete() <-chan struct{} {
return s.handshakeCtx.Done()
return s.handshakeCompleteChan
}
func (s *connection) Context() context.Context {
@ -732,7 +731,7 @@ func (s *connection) idleTimeoutStartTime() time.Time {
}
func (s *connection) handleHandshakeComplete() error {
defer s.handshakeCtxCancel()
defer close(s.handshakeCompleteChan)
// Once the handshake completes, we have derived 1-RTT keys.
// There's no point in queueing undecryptable packets for later decryption anymore.
s.undecryptablePackets = nil
@ -2425,10 +2424,17 @@ func (s *connection) GetVersion() protocol.Version {
return s.version
}
func (s *connection) NextConnection() Connection {
<-s.HandshakeComplete()
s.streamsMap.UseResetMaps()
return s
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
// The handshake might fail after the server rejected 0-RTT.
// This could happen if the Finished message is malformed or never received.
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-s.Context().Done():
case <-s.HandshakeComplete():
s.streamsMap.UseResetMaps()
}
return s, nil
}
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.

View file

@ -82,7 +82,13 @@ func (c *SingleDestinationRoundTripper) Start() Connection {
func (c *SingleDestinationRoundTripper) init() {
c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
c.requestWriter = newRequestWriter()
c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger)
c.hconn = newConnection(
c.Connection.Context(),
c.Connection,
c.EnableDatagrams,
protocol.PerspectiveClient,
c.Logger,
)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(c.hconn); err != nil {

View file

@ -37,6 +37,7 @@ type Connection interface {
type connection struct {
quic.Connection
ctx context.Context
perspective protocol.Perspective
logger *slog.Logger
@ -53,12 +54,14 @@ type connection struct {
}
func newConnection(
ctx context.Context,
quicConn quic.Connection,
enableDatagrams bool,
perspective protocol.Perspective,
logger *slog.Logger,
) *connection {
c := &connection{
return &connection{
ctx: ctx,
Connection: quicConn,
perspective: perspective,
logger: logger,
@ -67,7 +70,6 @@ func newConnection(
receivedSettings: make(chan struct{}),
streams: make(map[protocol.StreamID]*datagrammer),
}
return c
}
func (c *connection) clearStream(id quic.StreamID) {
@ -264,3 +266,5 @@ func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSetti
// Settings returns the settings received on this connection.
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
func (c *connection) Settings() *Settings { return c.settings }
func (c *connection) Context() context.Context { return c.ctx }

View file

@ -94,6 +94,10 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
if config == nil {
return nil, nil
}
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = config.DecryptTicket(nil, tls.ConnectionState{})
config = config.Clone()
config.NextProtos = []string{proto}
return config, nil
@ -194,9 +198,8 @@ type Server struct {
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
// ConnContext optionally specifies a function that modifies
// the context used for a new connection c. The provided ctx
// has a ServerContextKey value.
// ConnContext optionally specifies a function that modifies the context used for a new connection c.
// The provided ctx has a ServerContextKey value.
ConnContext func(ctx context.Context, c quic.Connection) context.Context
Logger *slog.Logger
@ -436,7 +439,19 @@ func (s *Server) handleConn(conn quic.Connection) error {
}).Append(b)
str.Write(b)
ctx := conn.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}
hconn := newConnection(
ctx,
conn,
s.EnableDatagrams,
protocol.PerspectiveServer,
@ -533,17 +548,10 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
s.Logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI)
}
ctx := str.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn.Connection)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}
ctx, cancel := context.WithCancel(conn.Context())
req = req.WithContext(ctx)
context.AfterFunc(str.Context(), cancel)
r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.Logger)
handler := s.Handler
if handler == nil {

View file

@ -57,9 +57,11 @@ var Err0RTTRejected = errors.New("0-RTT rejected")
// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
// It is set on the Connection.Context() context,
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
var ConnectionTracingKey = connTracingCtxKey{}
// ConnectionTracingID is the type of the context value saved under the ConnectionTracingKey.
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
type ConnectionTracingID uint64
type connTracingCtxKey struct{}
@ -222,7 +224,7 @@ type EarlyConnection interface {
// however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{}
NextConnection() Connection
NextConnection(context.Context) (Connection, error)
}
// StatelessResetKey is a key used to derive stateless reset tokens.
@ -334,7 +336,6 @@ type Config struct {
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1280 bytes in size.
DisablePathMTUDiscovery bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// Only valid for the server.

View file

@ -123,44 +123,12 @@ func NewCryptoSetupServer(
)
cs.allow0RTT = allow0RTT
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
cs.tlsConf = quicConf.TLSConfig
cs.conn = tls.QUICServer(quicConf)
tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
return cs
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
c = c.Clone()
// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
c.MinVersion = tls.VersionTLS13
// We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
}
func newCryptoSetup(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
@ -203,8 +171,8 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) StartHandshake() error {
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
err := h.conn.Start(context.WithValue(ctx, QUICVersionContextKey, h.version))
if err != nil {
return wrapError(err)
}
@ -376,9 +344,7 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
}); err != nil {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.

View file

@ -1,6 +1,7 @@
package handshake
import (
"context"
"crypto/tls"
"errors"
"io"
@ -91,7 +92,7 @@ type Event struct {
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
StartHandshake() error
StartHandshake(context.Context) error
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)

View file

@ -1,4 +1,4 @@
package handshake
package qtls
import (
"net"

View file

@ -4,20 +4,23 @@ import (
"bytes"
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig
func SetupConfigForServer(
conf *tls.Config,
localAddr, remoteAddr net.Addr,
getData func() []byte,
handleSessionTicket func([]byte, bool) bool,
) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
qconf.TLSConfig = conf
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
@ -58,6 +61,29 @@ func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte,
return state, nil
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}
func SetupConfigForClient(

View file

@ -25,16 +25,80 @@ const (
maxMTUDiff = 20
// send a probe packet every mtuProbeDelay RTTs
mtuProbeDelay = 5
// Once maxLostMTUProbes MTU probe packets larger than a certain size are lost,
// MTU discovery won't probe for larger MTUs than this size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
maxLostMTUProbes = 3
)
// The Path MTU is found by sending a larger packet every now and then.
// If the packet is acknowledged, we conclude that the path supports this larger packet size.
// If the packet is lost, this can mean one of two things:
// 1. The path doesn't support this larger packet size, or
// 2. The packet was lost due to packet loss, independent of its size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
// For simplicty, the following example use maxLostMTUProbes = 2.
//
// Initialization:
// |------------------------------------------------------------------------------|
// min max
//
// The first MTU probe packet will have size (min+max)/2.
// Assume that this packet is acknowledged. We can now move the min marker,
// and continue the search in the resulting interval.
//
// If 1st probe packet acknowledged:
// |---------------------------------------|--------------------------------------|
// min max
//
// If 1st probe packet lost:
// |---------------------------------------|--------------------------------------|
// min lost[0] max
//
// We can't conclude that the path doesn't support this packet size, since the loss of the probe
// packet could have been unrelated to the packet size. A larger probe packet will be sent later on.
// After a loss, the next probe packet has size (min+lost[0])/2.
// Now assume this probe packet is acknowledged:
//
// 2nd probe packet acknowledged:
// |------------------|--------------------|--------------------------------------|
// min lost[0] max
//
// First of all, we conclude that the path supports at least this MTU. That's progress!
// Second, we probe a bit more aggressively with the next probe packet:
// After an acknowledgement, the next probe packet has size (min+max)/2.
// This means we'll send a packet larger than the first probe packet (which was lost).
//
// If 3rd probe packet acknowledged:
// |-------------------------------------------------|----------------------------|
// min max
//
// We can conclude that the loss of the 1st probe packet was not due to its size, and
// continue searching in a much smaller interval now.
//
// If 3rd probe packet lost:
// |------------------|--------------------|---------|----------------------------|
// min lost[0] max
//
// Since in our example numPTOProbes = 2, and we lost 2 packets smaller than max, we
// conclude that this packet size is not supported on the path, and reduce the maximum
// value of the search interval.
//
// MTU discovery concludes once the interval min and max has been narrowed down to maxMTUDiff.
type mtuFinder struct {
lastProbeTime time.Time
mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
current protocol.ByteCount
max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer)
min protocol.ByteCount
limit protocol.ByteCount
// on initialization, we treat the maximum size as the first "lost" packet
lost [maxLostMTUProbes]protocol.ByteCount
lastProbeWasLost bool
tracer *logging.ConnectionTracer
}
@ -47,33 +111,43 @@ func newMTUDiscoverer(
mtuIncreased func(protocol.ByteCount),
tracer *logging.ConnectionTracer,
) *mtuFinder {
return &mtuFinder{
f := &mtuFinder{
inFlight: protocol.InvalidByteCount,
current: start,
max: max,
min: start,
limit: max,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
tracer: tracer,
}
for i := range f.lost {
if i == 0 {
f.lost[i] = max
continue
}
f.lost[i] = protocol.InvalidByteCount
}
return f
}
func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1
return f.max()-f.min <= maxMTUDiff+1
}
func (f *mtuFinder) SetMax(max protocol.ByteCount) {
f.max = max
func (f *mtuFinder) max() protocol.ByteCount {
for i, v := range f.lost {
if v == protocol.InvalidByteCount {
return f.lost[i-1]
}
}
return f.lost[len(f.lost)-1]
}
func (f *mtuFinder) Start() {
if f.max == protocol.InvalidByteCount {
panic("invalid")
}
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
if f.max == 0 || f.lastProbeTime.IsZero() {
if f.lastProbeTime.IsZero() {
return false
}
if f.inFlight != protocol.InvalidByteCount || f.done() {
@ -83,7 +157,12 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
}
func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
size := (f.max + f.current) / 2
var size protocol.ByteCount
if f.lastProbeWasLost {
size = (f.min + f.lost[0]) / 2
} else {
size = (f.min + f.max()) / 2
}
f.lastProbeTime = time.Now()
f.inFlight = size
return ackhandler.Frame{
@ -93,7 +172,7 @@ func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
}
func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.current
return f.min
}
type mtuFinderAckHandler struct {
@ -108,7 +187,25 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
panic("OnAcked callback called although there's no MTU probe packet in flight")
}
h.inFlight = protocol.InvalidByteCount
h.current = size
h.min = size
h.lastProbeWasLost = false
// remove all values smaller than size from the lost array
var j int
for i, v := range h.lost {
if size < v {
j = i
break
}
}
if j > 0 {
for i := 0; i < len(h.lost); i++ {
if i+j < len(h.lost) {
h.lost[i] = h.lost[i+j]
} else {
h.lost[i] = protocol.InvalidByteCount
}
}
}
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
@ -120,6 +217,13 @@ func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
if size == protocol.InvalidByteCount {
panic("OnLost callback called although there's no MTU probe packet in flight")
}
h.max = size
h.lastProbeWasLost = true
h.inFlight = protocol.InvalidByteCount
for i, v := range h.lost {
if size < v {
copy(h.lost[i+1:], h.lost[i:])
h.lost[i] = size
break
}
}
}

View file

@ -478,7 +478,6 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.ctxCancel(err)
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalWrite()

View file

@ -76,8 +76,12 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func(context.Context) context.Context
// set as a member, so they can be set in the tests
newConn func(
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
@ -92,7 +96,6 @@ type baseServer struct {
*handshake.TokenGenerator,
bool, /* client address validated by an address validation token */
*logging.ConnectionTracer,
ConnectionTracingID,
utils.Logger,
protocol.Version,
) quicConn
@ -231,6 +234,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
@ -243,6 +247,7 @@ func newServer(
) *baseServer {
s := &baseServer{
conn: conn,
connContext: connContext,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
@ -631,7 +636,26 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
var conn quicConn
tracingID := nextConnTracingID()
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext(ctx)
if ctx == nil {
panic("quic: ConnContext returned nil")
}
// There's no guarantee that the application returns a context
// that's derived from the context we passed into ConnContext.
// We need to make sure that both contexts are cancelled.
var cancel2 context.CancelCauseFunc
ctx, cancel2 = context.WithCancelCause(ctx)
cancel = func(cause error) {
cancel1(cause)
cancel2(cause)
}
} else {
cancel = cancel1
}
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
@ -639,7 +663,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
tracer = config.Tracer(ctx, protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
@ -647,6 +671,8 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
@ -661,7 +687,6 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.tokenGenerator,
clientAddrVerified,
tracer,
tracingID,
s.logger,
hdr.Version,
)

View file

@ -89,6 +89,17 @@ type Transport struct {
// implementation of this callback (negating its return value).
VerifySourceAddress func(net.Addr) bool
// ConnContext is called when the server accepts a new connection.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
// * the context used in Config.Tracer
// * the context returned from Connection.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func(context.Context) context.Context
// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Tracer *logging.Tracer
@ -168,6 +179,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
t.conn,
t.handlerMap,
t.connIDGenerator,
t.ConnContext,
tlsConf,
conf,
t.Tracer,