Update deps

This commit is contained in:
Frank Denis 2024-10-23 14:22:28 +02:00
parent 4ed02b02df
commit 3f551f4b66
56 changed files with 1094 additions and 479 deletions

View file

@ -1,6 +1,6 @@
BSD 2-Clause License
Copyright (c) 2018-2023, Frank Denis
Copyright (c) 2018-2024, Frank Denis
All rights reserved.
Redistribution and use in source and binary forms, with or without

View file

@ -43,3 +43,4 @@ issues:
- path: _test\.go
linters:
- exhaustive
- prealloc

View file

@ -6,16 +6,19 @@ import (
"github.com/quic-go/quic-go/quicvarint"
)
// CapsuleType is the type of the capsule.
// CapsuleType is the type of the capsule
type CapsuleType uint64
// CapsuleProtocolHeader is the header value used to advertise support for the capsule protocol
const CapsuleProtocolHeader = "Capsule-Protocol"
type exactReader struct {
R *io.LimitedReader
R io.LimitedReader
}
func (r *exactReader) Read(b []byte) (int, error) {
n, err := r.R.Read(b)
if r.R.N > 0 {
if err == io.EOF && r.R.N > 0 {
return n, io.ErrUnexpectedEOF
}
return n, err
@ -35,7 +38,7 @@ func (r *countingByteReader) ReadByte() (byte, error) {
}
// ParseCapsule parses the header of a Capsule.
// It returns an io.LimitedReader that can be used to read the Capsule value.
// It returns an io.Reader that can be used to read the Capsule value.
// The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
cbr := countingByteReader{ByteReader: r}
@ -55,7 +58,7 @@ func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
}
return 0, nil, err
}
return CapsuleType(ct), &exactReader{R: io.LimitReader(r, int64(l)).(*io.LimitedReader)}, nil
return CapsuleType(ct), &exactReader{R: io.LimitedReader{R: r, N: int64(l)}}, nil
}
// WriteCapsule writes a capsule

View file

@ -9,7 +9,6 @@ import (
"net/http"
"net/http/httptrace"
"net/textproto"
"sync"
"time"
"github.com/quic-go/quic-go"
@ -38,102 +37,122 @@ var defaultQuicConfig = &quic.Config{
KeepAlivePeriod: 10 * time.Second,
}
// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server.
type SingleDestinationRoundTripper struct {
Connection quic.Connection
// ClientConn is an HTTP/3 client doing requests to a single remote server.
type ClientConn struct {
connection
// Enable support for HTTP/3 datagrams (RFC 9297).
// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
EnableDatagrams bool
// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting enableDatagrams.
enableDatagrams bool
// Additional HTTP/3 settings.
// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
additionalSettings map[uint64]uint64
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// maxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
maxResponseHeaderBytes uint64
// DisableCompression, if true, prevents the Transport from requesting compression with an
// disableCompression, if true, prevents the Transport from requesting compression with an
// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
// decoded in the Response.Body.
// However, if the user explicitly requested gzip it is not automatically uncompressed.
DisableCompression bool
disableCompression bool
Logger *slog.Logger
logger *slog.Logger
initOnce sync.Once
hconn *connection
requestWriter *requestWriter
decoder *qpack.Decoder
}
var _ http.RoundTripper = &SingleDestinationRoundTripper{}
var _ http.RoundTripper = &ClientConn{}
func (c *SingleDestinationRoundTripper) Start() Connection {
c.initOnce.Do(func() { c.init() })
return c.hconn
}
// Deprecated: SingleDestinationRoundTripper was renamed to ClientConn.
// It can be obtained by calling NewClientConn on a Transport.
type SingleDestinationRoundTripper = ClientConn
func (c *SingleDestinationRoundTripper) init() {
func newClientConn(
conn quic.Connection,
enableDatagrams bool,
additionalSettings map[uint64]uint64,
streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error),
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool),
maxResponseHeaderBytes int64,
disableCompression bool,
logger *slog.Logger,
) *ClientConn {
c := &ClientConn{
enableDatagrams: enableDatagrams,
additionalSettings: additionalSettings,
disableCompression: disableCompression,
logger: logger,
}
if maxResponseHeaderBytes <= 0 {
c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes
} else {
c.maxResponseHeaderBytes = uint64(maxResponseHeaderBytes)
}
c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
c.requestWriter = newRequestWriter()
c.hconn = newConnection(
c.Connection.Context(),
c.Connection,
c.EnableDatagrams,
c.connection = *newConnection(
conn.Context(),
conn,
c.enableDatagrams,
protocol.PerspectiveClient,
c.Logger,
c.logger,
0,
)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(c.hconn); err != nil {
if c.Logger != nil {
c.Logger.Debug("Setting up connection failed", "error", err)
if err := c.setupConn(); err != nil {
if c.logger != nil {
c.logger.Debug("Setting up connection failed", "error", err)
}
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
}
}()
if c.StreamHijacker != nil {
go c.handleBidirectionalStreams()
if streamHijacker != nil {
go c.handleBidirectionalStreams(streamHijacker)
}
go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker)
go c.connection.handleUnidirectionalStreams(uniStreamHijacker)
return c
}
func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error {
// OpenRequestStream opens a new request stream on the HTTP/3 connection.
func (c *ClientConn) OpenRequestStream(ctx context.Context) (RequestStream, error) {
return c.connection.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes)
}
func (c *ClientConn) setupConn() error {
// open the control stream
str, err := conn.OpenUniStream()
str, err := c.connection.OpenUniStream()
if err != nil {
return err
}
b := make([]byte, 0, 64)
b = quicvarint.Append(b, streamTypeControlStream)
// send the SETTINGS frame
b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b)
b = (&settingsFrame{Datagram: c.enableDatagrams, Other: c.additionalSettings}).Append(b)
_, err = str.Write(b)
return err
}
func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)) {
for {
str, err := c.hconn.AcceptStream(context.Background())
str, err := c.connection.AcceptStream(context.Background())
if err != nil {
if c.Logger != nil {
c.Logger.Debug("accepting bidirectional stream failed", "error", err)
if c.logger != nil {
c.logger.Debug("accepting bidirectional stream failed", "error", err)
}
return
}
fp := &frameParser{
r: str,
conn: c.hconn,
conn: &c.connection,
unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return c.StreamHijacker(ft, id, str, e)
id := c.connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return streamHijacker(ft, id, str, e)
},
}
go func() {
@ -141,26 +160,17 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
return
}
if err != nil {
if c.Logger != nil {
c.Logger.Debug("error handling stream", "error", err)
if c.logger != nil {
c.logger.Debug("error handling stream", "error", err)
}
}
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}()
}
}
func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 {
if c.MaxResponseHeaderBytes <= 0 {
return defaultMaxResponseHeaderBytes
}
return uint64(c.MaxResponseHeaderBytes)
}
// RoundTrip executes a request and returns a response
func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
c.initOnce.Do(func() { c.init() })
func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
rsp, err := c.roundTrip(req)
if err != nil && req.Context().Err() != nil {
// if the context was canceled, return the context cancellation error
@ -169,7 +179,7 @@ func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Resp
return rsp, err
}
func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) {
func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) {
// Immediately send out this request, if this is a 0-RTT request.
switch req.Method {
case MethodGet0RTT:
@ -200,17 +210,23 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp
connCtx := c.Connection.Context()
// wait for the server's SETTINGS frame to arrive
select {
case <-c.hconn.ReceivedSettings():
case <-c.connection.ReceivedSettings():
case <-connCtx.Done():
return nil, context.Cause(connCtx)
}
if !c.hconn.Settings().EnableExtendedConnect {
if !c.connection.Settings().EnableExtendedConnect {
return nil, errors.New("http3: server didn't enable Extended CONNECT")
}
}
reqDone := make(chan struct{})
str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes())
str, err := c.connection.openRequestStream(
req.Context(),
c.requestWriter,
reqDone,
c.disableCompression,
c.maxResponseHeaderBytes,
)
if err != nil {
return nil, err
}
@ -238,12 +254,6 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp
return rsp, maybeReplaceError(err)
}
func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) {
c.initOnce.Do(func() { c.init() })
return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes())
}
// cancelingReader reads from the io.Reader.
// It cancels writing on the stream if any error other than io.EOF occurs.
type cancelingReader struct {
@ -259,7 +269,7 @@ func (r *cancelingReader) Read(b []byte) (int, error) {
return n, err
}
func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
func (c *ClientConn) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
defer body.Close()
buf := make([]byte, bodyCopyBufferSize)
sr := &cancelingReader{str: str, r: body}
@ -283,7 +293,7 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read
return err
}
func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
if err := str.SendRequestHeader(req); err != nil {
return nil, err
}
@ -299,8 +309,8 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques
contentLength = req.ContentLength
}
if err := c.sendRequestBody(str, req.Body, contentLength); err != nil {
if c.Logger != nil {
c.Logger.Debug("error writing request", "error", err)
if c.logger != nil {
c.logger.Debug("error writing request", "error", err)
}
}
str.Close()
@ -337,7 +347,7 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques
}
break
}
connState := c.hconn.ConnectionState().TLS
connState := c.connection.ConnectionState().TLS
res.TLS = &connState
res.Request = req
return res, nil

View file

@ -170,7 +170,7 @@ func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string)
return c.Connection.CloseWithError(code, msg)
}
func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
func (c *connection) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
var (
rcvdControlStr atomic.Bool
rcvdQPACKEncoderStr atomic.Bool
@ -316,10 +316,12 @@ func (c *connection) receiveDatagrams() error {
}
// ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received.
// Settings can be optained from the Settings method after the channel was closed.
func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings }
// Settings returns the settings received on this connection.
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
func (c *connection) Settings() *Settings { return c.settings }
// Context returns the context of the underlying QUIC connection.
func (c *connection) Context() context.Context { return c.ctx }

View file

@ -66,7 +66,8 @@ func (p *frameParser) ParseNext() (frame, error) {
return parseSettingsFrame(p.r, l)
case 0x3: // CANCEL_PUSH
case 0x5: // PUSH_PROMISE
case 0x7: // GOAWAY
case 0x7:
return parseGoAwayFrame(qr, l)
case 0xd: // MAX_PUSH_ID
case 0x2, 0x6, 0x8, 0x9:
p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
@ -194,3 +195,27 @@ func (f *settingsFrame) Append(b []byte) []byte {
}
return b
}
type goAwayFrame struct {
StreamID quic.StreamID
}
func parseGoAwayFrame(r io.ByteReader, l uint64) (*goAwayFrame, error) {
frame := &goAwayFrame{}
cbr := countingByteReader{ByteReader: r}
id, err := quicvarint.Read(&cbr)
if err != nil {
return nil, err
}
if cbr.Read != int(l) {
return nil, errors.New("GOAWAY frame: inconsistent length")
}
frame.StreamID = quic.StreamID(id)
return frame, nil
}
func (f *goAwayFrame) Append(b []byte) []byte {
b = quicvarint.Append(b, 0x7)
b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(f.StreamID))))
return quicvarint.Append(b, uint64(f.StreamID))
}

View file

@ -13,6 +13,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
@ -45,6 +46,8 @@ const (
streamTypeQPACKDecoderStream = 3
)
const goawayTimeout = 5 * time.Second
// A QUICEarlyListener listens for incoming QUIC connections.
type QUICEarlyListener interface {
Accept(context.Context) (quic.EarlyConnection, error)
@ -213,7 +216,13 @@ type Server struct {
mutex sync.RWMutex
listeners map[*QUICEarlyListener]listenerInfo
closed bool
closed bool
closeCtx context.Context // canceled when the server is closed
closeCancel context.CancelFunc // cancels the closeCtx
graceCtx context.Context // canceled when the server is closed or gracefully closed
graceCancel context.CancelFunc // cancels the graceCtx
connCount atomic.Int64
connHandlingDone chan struct{}
altSvcHeader string
}
@ -265,8 +274,31 @@ func (s *Server) Serve(conn net.PacketConn) error {
return s.serveListener(ln)
}
// init initializes the contexts used for shutting down the server.
// It must be called with the mutex held.
func (s *Server) init() {
if s.closeCtx == nil {
s.closeCtx, s.closeCancel = context.WithCancel(context.Background())
s.graceCtx, s.graceCancel = context.WithCancel(s.closeCtx)
}
s.connHandlingDone = make(chan struct{}, 1)
}
func (s *Server) decreaseConnCount() {
if s.connCount.Add(-1) == 0 && s.graceCtx.Err() != nil {
close(s.connHandlingDone)
}
}
// ServeQUICConn serves a single QUIC connection.
func (s *Server) ServeQUICConn(conn quic.Connection) error {
s.mutex.Lock()
s.init()
s.mutex.Unlock()
s.connCount.Add(1)
defer s.decreaseConnCount()
return s.handleConn(conn)
}
@ -289,14 +321,17 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error {
func (s *Server) serveListener(ln QUICEarlyListener) error {
for {
conn, err := ln.Accept(context.Background())
if err == quic.ErrServerClosed {
conn, err := ln.Accept(s.graceCtx)
// server closed
if errors.Is(err, quic.ErrServerClosed) || s.graceCtx.Err() != nil {
return http.ErrServerClosed
}
if err != nil {
return err
}
s.connCount.Add(1)
go func() {
defer s.decreaseConnCount()
if err := s.handleConn(conn); err != nil {
if s.Logger != nil {
s.Logger.Debug("handling connection failed", "error", err)
@ -430,6 +465,7 @@ func (s *Server) addListener(l *QUICEarlyListener) error {
if s.listeners == nil {
s.listeners = make(map[*QUICEarlyListener]listenerInfo)
}
s.init()
laddr := (*l).Addr()
if port, err := extractPort(laddr.String()); err == nil {
@ -453,9 +489,12 @@ func (s *Server) removeListener(l *QUICEarlyListener) {
s.generateAltSvcHeader()
}
// handleConn handles the HTTP/3 exchange on a QUIC connection.
// It blocks until all HTTP handlers for all streams have returned.
func (s *Server) handleConn(conn quic.Connection) error {
// send a SETTINGS frame
str, err := conn.OpenUniStream()
// open the control stream and send a SETTINGS frame, it's also used to send a GOAWAY frame later
// when the server is gracefully closed
ctrlStr, err := conn.OpenUniStream()
if err != nil {
return fmt.Errorf("opening the control stream failed: %w", err)
}
@ -466,7 +505,7 @@ func (s *Server) handleConn(conn quic.Connection) error {
ExtendedConnect: true,
Other: s.AdditionalSettings,
}).Append(b)
str.Write(b)
ctrlStr.Write(b)
ctx := conn.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
@ -487,21 +526,60 @@ func (s *Server) handleConn(conn quic.Connection) error {
s.Logger,
s.IdleTimeout,
)
go hconn.HandleUnidirectionalStreams(s.UniStreamHijacker)
go hconn.handleUnidirectionalStreams(s.UniStreamHijacker)
var nextStreamID quic.StreamID
var wg sync.WaitGroup
var handleErr error
// Process all requests immediately.
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
for {
str, datagrams, err := hconn.acceptStream(context.Background())
str, datagrams, err := hconn.acceptStream(s.graceCtx)
if err != nil {
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) {
return nil
// server (not gracefully) closed, close the connection immediately
if s.closeCtx.Err() != nil {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
handleErr = http.ErrServerClosed
break
}
return fmt.Errorf("accepting stream failed: %w", err)
// gracefully closed, send GOAWAY frame and wait for requests to complete or grace period to end
// new requests will be rejected and shouldn't be sent
if s.graceCtx.Err() != nil {
b = (&goAwayFrame{StreamID: nextStreamID}).Append(b[:0])
// set a deadline to send the GOAWAY frame
ctrlStr.SetWriteDeadline(time.Now().Add(goawayTimeout))
ctrlStr.Write(b)
select {
case <-hconn.Context().Done():
// we expect the client to eventually close the connection after receiving the GOAWAY
case <-s.closeCtx.Done():
// close the connection after graceful period
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
}
handleErr = http.ErrServerClosed
break
}
var appErr *quic.ApplicationError
if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) {
handleErr = fmt.Errorf("accepting stream failed: %w", err)
}
break
}
go s.handleRequest(hconn, str, datagrams, hconn.decoder)
nextStreamID = str.StreamID() + 4
wg.Add(1)
go func() {
// handleRequest will return once the request has been handled,
// or the underlying connection is closed
defer wg.Done()
s.handleRequest(hconn, str, datagrams, hconn.decoder)
}()
}
wg.Wait()
return handleErr
}
func (s *Server) maxHeaderBytes() uint64 {
@ -606,7 +684,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
if logger == nil {
logger = slog.Default()
}
logger.Error("http: panic serving", "arg", p, "trace", buf)
logger.Error("http: panic serving", "arg", p, "trace", string(buf))
}
}()
handler.ServeHTTP(r, req)
@ -643,11 +721,17 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
// It is the caller's responsibility to close any connection passed to ServeQUICConn.
func (s *Server) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
s.closed = true
// server is never used
if s.closeCtx == nil {
return nil
}
s.closeCancel()
var err error
for ln := range s.listeners {
@ -655,14 +739,44 @@ func (s *Server) Close() error {
err = cerr
}
}
if s.connCount.Load() == 0 {
return err
}
// wait for all connections to be closed
<-s.connHandlingDone
return err
}
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) CloseGracefully(timeout time.Duration) error {
// TODO: implement
return nil
// Shutdown shuts down the server gracefully.
// The server sends a GOAWAY frame first, then or for all running requests to complete.
// Shutdown in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Shutdown(ctx context.Context) error {
s.mutex.Lock()
s.closed = true
// server is never used
if s.closeCtx == nil {
s.mutex.Unlock()
return nil
}
s.graceCancel()
s.mutex.Unlock()
if s.connCount.Load() == 0 {
return s.Close()
}
select {
case <-s.connHandlingDone: // all connections were closed
// When receiving a GOAWAY frame, HTTP/3 clients are expected to close the connection
// once all requests were successfully handled...
return s.Close()
case <-ctx.Done():
// ... however, clients handling long-lived requests (and misbehaving clients),
// might not do so before the context is cancelled.
// In this case, we close the server, which closes all existing connections
// (expect those passed to ServeQUICConn).
_ = s.Close()
return ctx.Err()
}
}
// ErrNoAltSvcPort is the error returned by SetQUICHeaders when no port was found
@ -690,11 +804,6 @@ func (s *Server) SetQUICHeaders(hdr http.Header) error {
return nil
}
// Deprecated: use SetQUICHeaders instead.
func (s *Server) SetQuicHeaders(hdr http.Header) error {
return s.SetQUICHeaders(hdr)
}
// ListenAndServeQUIC listens on the UDP network address addr and calls the
// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
// used when handler is nil.
@ -706,11 +815,6 @@ func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) er
return server.ListenAndServeTLS(certFile, keyFile)
}
// Deprecated: use ListenAndServeTLS instead.
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
return ListenAndServeTLS(addr, certFile, keyFile, handler)
}
// ListenAndServeTLS listens on the given network address for both TLS/TCP and QUIC
// connections in parallel. It returns if one of the two returns an error.
// http.DefaultServeMux is used when handler is nil.

View file

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"strings"
@ -30,7 +31,7 @@ type Settings struct {
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
// OnlyCachedConn controls whether the Transport may create a new QUIC connection.
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
OnlyCachedConn bool
}
@ -59,10 +60,8 @@ func (r *roundTripperWithCount) Close() error {
return nil
}
// RoundTripper implements the http.RoundTripper interface
type RoundTripper struct {
mutex sync.Mutex
// Transport implements the http.RoundTripper interface
type Transport struct {
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
@ -97,6 +96,13 @@ type RoundTripper struct {
// However, if the user explicitly requested gzip it is not automatically uncompressed.
DisableCompression bool
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
Logger *slog.Logger
mutex sync.Mutex
initOnce sync.Once
initErr error
@ -107,18 +113,56 @@ type RoundTripper struct {
}
var (
_ http.RoundTripper = &RoundTripper{}
_ io.Closer = &RoundTripper{}
_ http.RoundTripper = &Transport{}
_ io.Closer = &Transport{}
)
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
// Deprecated: RoundTripper was renamed to Transport.
type RoundTripper = Transport
// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
func (t *Transport) init() error {
if t.newClient == nil {
t.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
return newClientConn(
conn,
t.EnableDatagrams,
t.AdditionalSettings,
t.StreamHijacker,
t.UniStreamHijacker,
t.MaxResponseHeaderBytes,
t.DisableCompression,
t.Logger,
)
}
}
if t.QUICConfig == nil {
t.QUICConfig = defaultQuicConfig.Clone()
t.QUICConfig.EnableDatagrams = t.EnableDatagrams
}
if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams {
return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
}
if len(t.QUICConfig.Versions) == 0 {
t.QUICConfig = t.QUICConfig.Clone()
t.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
}
if len(t.QUICConfig.Versions) != 1 {
return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
if t.QUICConfig.MaxIncomingStreams == 0 {
t.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
}
return nil
}
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
r.initOnce.Do(func() { r.initErr = r.init() })
if r.initErr != nil {
return nil, r.initErr
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
t.initOnce.Do(func() { t.initErr = t.init() })
if t.initErr != nil {
return nil, t.initErr
}
if req.URL == nil {
@ -154,7 +198,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
}
hostname := authorityAddr(hostnameFromURL(req.URL))
cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn)
cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
@ -166,7 +210,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
}
if cl.dialErr != nil {
r.removeClient(hostname)
t.removeClient(hostname)
return nil, cl.dialErr
}
defer cl.useCount.Add(-1)
@ -176,12 +220,12 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
// so we remove the client from the cache so that subsequent trips reconnect
// context cancelation is excluded as is does not signify a connection error
if !errors.Is(err, context.Canceled) {
r.removeClient(hostname)
t.removeClient(hostname)
}
if isReused {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return r.RoundTripOpt(req, opt)
return t.RoundTripOpt(req, opt)
}
}
}
@ -189,51 +233,19 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
}
// RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) init() error {
if r.newClient == nil {
r.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
return &SingleDestinationRoundTripper{
Connection: conn,
EnableDatagrams: r.EnableDatagrams,
DisableCompression: r.DisableCompression,
AdditionalSettings: r.AdditionalSettings,
MaxResponseHeaderBytes: r.MaxResponseHeaderBytes,
}
}
}
if r.QUICConfig == nil {
r.QUICConfig = defaultQuicConfig.Clone()
r.QUICConfig.EnableDatagrams = r.EnableDatagrams
}
if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams {
return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
}
if len(r.QUICConfig.Versions) == 0 {
r.QUICConfig = r.QUICConfig.Clone()
r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
}
if len(r.QUICConfig.Versions) != 1 {
return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
if r.QUICConfig.MaxIncomingStreams == 0 {
r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
}
return nil
}
func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
t.mutex.Lock()
defer t.mutex.Unlock()
func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]*roundTripperWithCount)
if t.clients == nil {
t.clients = make(map[string]*roundTripperWithCount)
}
cl, ok := r.clients[hostname]
cl, ok := t.clients[hostname]
if !ok {
if onlyCached {
return nil, false, ErrNoCachedConn
@ -246,7 +258,7 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
go func() {
defer close(cl.dialing)
defer cancel()
conn, rt, err := r.dial(ctx, hostname)
conn, rt, err := t.dial(ctx, hostname)
if err != nil {
cl.dialErr = err
return
@ -254,12 +266,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
cl.conn = conn
cl.rt = rt
}()
r.clients[hostname] = cl
t.clients[hostname] = cl
}
select {
case <-cl.dialing:
if cl.dialErr != nil {
delete(r.clients, hostname)
delete(t.clients, hostname)
return nil, false, cl.dialErr
}
select {
@ -273,12 +285,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache
return cl, isReused, nil
}
func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
func (t *Transport) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
var tlsConf *tls.Config
if r.TLSClientConfig == nil {
if t.TLSClientConfig == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = r.TLSClientConfig.Clone()
tlsConf = t.TLSClientConfig.Clone()
}
if tlsConf.ServerName == "" {
sni, _, err := net.SplitHostPort(hostname)
@ -289,61 +301,79 @@ func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyCon
tlsConf.ServerName = sni
}
// Replace existing ALPNs by H3
tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])}
tlsConf.NextProtos = []string{versionToALPN(t.QUICConfig.Versions[0])}
dial := r.Dial
dial := t.Dial
if dial == nil {
if r.transport == nil {
if t.transport == nil {
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, nil, err
}
r.transport = &quic.Transport{Conn: udpConn}
t.transport = &quic.Transport{Conn: udpConn}
}
dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
return t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
}
}
conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig)
conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig)
if err != nil {
return nil, nil, err
}
return conn, r.newClient(conn), nil
return conn, t.newClient(conn), nil
}
func (r *RoundTripper) removeClient(hostname string) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
func (t *Transport) removeClient(hostname string) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.clients == nil {
return
}
delete(r.clients, hostname)
delete(t.clients, hostname)
}
// Close closes the QUIC connections that this RoundTripper has used.
// It also closes the underlying UDPConn if it is not nil.
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, cl := range r.clients {
// NewClientConn creates a new HTTP/3 client connection on top of a QUIC connection.
// Most users should use RoundTrip instead of creating a connection directly.
// Specifically, it is not needed to perform GET, POST, HEAD and CONNECT requests.
//
// Obtaining a ClientConn is only needed for more advanced use cases, such as
// using Extended CONNECT for WebTransport or the various MASQUE protocols.
func (t *Transport) NewClientConn(conn quic.Connection) *ClientConn {
return newClientConn(
conn,
t.EnableDatagrams,
t.AdditionalSettings,
t.StreamHijacker,
t.UniStreamHijacker,
t.MaxResponseHeaderBytes,
t.DisableCompression,
t.Logger,
)
}
// Close closes the QUIC connections that this Transport has used.
func (t *Transport) Close() error {
t.mutex.Lock()
defer t.mutex.Unlock()
for _, cl := range t.clients {
if err := cl.Close(); err != nil {
return err
}
}
r.clients = nil
if r.transport != nil {
if err := r.transport.Close(); err != nil {
t.clients = nil
if t.transport != nil {
if err := t.transport.Close(); err != nil {
return err
}
if err := r.transport.Conn.Close(); err != nil {
if err := t.transport.Conn.Close(); err != nil {
return err
}
r.transport = nil
t.transport = nil
}
return nil
}
@ -376,13 +406,17 @@ func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}
func (r *RoundTripper) CloseIdleConnections() {
r.mutex.Lock()
defer r.mutex.Unlock()
for hostname, cl := range r.clients {
// CloseIdleConnections closes any QUIC connections in the transport's pool that are currently idle.
// An idle connection is one that was previously used for requests but is now sitting unused.
// This method does not interrupt any connections currently in use.
// It also does not affect connections obtained via NewClientConn.
func (t *Transport) CloseIdleConnections() {
t.mutex.Lock()
defer t.mutex.Unlock()
for hostname, cl := range t.clients {
if cl.useCount.Load() == 0 {
cl.Close()
delete(r.clients, hostname)
delete(t.clients, hostname)
}
}
}

View file

@ -83,6 +83,29 @@ const (
EventHandshakeComplete
)
func (k EventKind) String() string {
switch k {
case EventNoEvent:
return "EventNoEvent"
case EventWriteInitialData:
return "EventWriteInitialData"
case EventWriteHandshakeData:
return "EventWriteHandshakeData"
case EventReceivedReadKeys:
return "EventReceivedReadKeys"
case EventDiscard0RTTKeys:
return "EventDiscard0RTTKeys"
case EventReceivedTransportParameters:
return "EventReceivedTransportParameters"
case EventRestoredTransportParameters:
return "EventRestoredTransportParameters"
case EventHandshakeComplete:
return "EventHandshakeComplete"
default:
return "Unknown EventKind"
}
}
// Event is a handshake event.
type Event struct {
Kind EventKind

View file

@ -27,11 +27,6 @@ type RTTStats struct {
maxAckDelay time.Duration
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{}
}
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
@ -113,19 +108,3 @@ func (r *RTTStats) SetInitialRTT(t time.Duration) {
r.smoothedRTT = t
r.latestRTT = t
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = max(r.meanDeviation, (r.smoothedRTT - r.latestRTT).Abs())
r.smoothedRTT = max(r.smoothedRTT, r.latestRTT)
}

View file

@ -253,6 +253,9 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNe
if s.cancelledLocally { // duplicate call to CancelRead
return false
}
if s.closeForShutdownErr != nil {
return false
}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return false

View file

@ -423,6 +423,10 @@ func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock()
if s.closeForShutdownErr != nil {
s.mutex.Unlock()
return
}
if !remote {
s.cancellationFlagged = true
if s.cancelWriteErr != nil {