diff --git a/go.mod b/go.mod index 7b8df617..5624b631 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/miekg/dns v1.1.62 github.com/opencoff/go-sieve v0.2.1 github.com/powerman/check v1.7.0 - github.com/quic-go/quic-go v0.47.0 + github.com/quic-go/quic-go v0.48.1 golang.org/x/crypto v0.27.0 golang.org/x/net v0.29.0 golang.org/x/sys v0.25.0 diff --git a/go.sum b/go.sum index 2bee31b1..ba35f5be 100644 --- a/go.sum +++ b/go.sum @@ -73,8 +73,8 @@ github.com/powerman/deepequal v0.1.0 h1:sVwtyTsBuYIvdbLR1O2wzRY63YgPqdGZmk/o80l+ github.com/powerman/deepequal v0.1.0/go.mod h1:3k7aG/slufBhUANdN67o/UPg8i5YaiJ6FmibWX0cn04= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.47.0 h1:yXs3v7r2bm1wmPTYNLKAAJTHMYkPEsfYJmTazXrCZ7Y= -github.com/quic-go/quic-go v0.47.0/go.mod h1:3bCapYsJvXGZcipOHuu7plYtaV6tnF+z7wIFsU0WK9E= +github.com/quic-go/quic-go v0.48.1 h1:y/8xmfWI9qmGTc+lBr4jKRUWLGSlSigv847ULJ4hYXA= +github.com/quic-go/quic-go v0.48.1/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs= github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg= diff --git a/vendor/github.com/quic-go/quic-go/.golangci.yml b/vendor/github.com/quic-go/quic-go/.golangci.yml index 7c4b71c0..63b40cc3 100644 --- a/vendor/github.com/quic-go/quic-go/.golangci.yml +++ b/vendor/github.com/quic-go/quic-go/.golangci.yml @@ -43,3 +43,4 @@ issues: - path: _test\.go linters: - exhaustive + - prealloc diff --git a/vendor/github.com/quic-go/quic-go/http3/capsule.go b/vendor/github.com/quic-go/quic-go/http3/capsule.go index 69d4037a..132e6c32 100644 --- a/vendor/github.com/quic-go/quic-go/http3/capsule.go +++ b/vendor/github.com/quic-go/quic-go/http3/capsule.go @@ -6,16 +6,19 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -// CapsuleType is the type of the capsule. +// CapsuleType is the type of the capsule type CapsuleType uint64 +// CapsuleProtocolHeader is the header value used to advertise support for the capsule protocol +const CapsuleProtocolHeader = "Capsule-Protocol" + type exactReader struct { - R *io.LimitedReader + R io.LimitedReader } func (r *exactReader) Read(b []byte) (int, error) { n, err := r.R.Read(b) - if r.R.N > 0 { + if err == io.EOF && r.R.N > 0 { return n, io.ErrUnexpectedEOF } return n, err @@ -35,7 +38,7 @@ func (r *countingByteReader) ReadByte() (byte, error) { } // ParseCapsule parses the header of a Capsule. -// It returns an io.LimitedReader that can be used to read the Capsule value. +// It returns an io.Reader that can be used to read the Capsule value. // The Capsule value must be read entirely (i.e. until the io.EOF) before using r again. func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) { cbr := countingByteReader{ByteReader: r} @@ -55,7 +58,7 @@ func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) { } return 0, nil, err } - return CapsuleType(ct), &exactReader{R: io.LimitReader(r, int64(l)).(*io.LimitedReader)}, nil + return CapsuleType(ct), &exactReader{R: io.LimitedReader{R: r, N: int64(l)}}, nil } // WriteCapsule writes a capsule diff --git a/vendor/github.com/quic-go/quic-go/http3/client.go b/vendor/github.com/quic-go/quic-go/http3/client.go index e60acffe..83502240 100644 --- a/vendor/github.com/quic-go/quic-go/http3/client.go +++ b/vendor/github.com/quic-go/quic-go/http3/client.go @@ -9,7 +9,6 @@ import ( "net/http" "net/http/httptrace" "net/textproto" - "sync" "time" "github.com/quic-go/quic-go" @@ -38,102 +37,122 @@ var defaultQuicConfig = &quic.Config{ KeepAlivePeriod: 10 * time.Second, } -// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server. -type SingleDestinationRoundTripper struct { - Connection quic.Connection +// ClientConn is an HTTP/3 client doing requests to a single remote server. +type ClientConn struct { + connection // Enable support for HTTP/3 datagrams (RFC 9297). - // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. - EnableDatagrams bool + // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting enableDatagrams. + enableDatagrams bool // Additional HTTP/3 settings. // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). - AdditionalSettings map[uint64]uint64 - StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) - UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) + additionalSettings map[uint64]uint64 - // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // maxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. - // Zero means to use a default limit. - MaxResponseHeaderBytes int64 + maxResponseHeaderBytes uint64 - // DisableCompression, if true, prevents the Transport from requesting compression with an + // disableCompression, if true, prevents the Transport from requesting compression with an // "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value. // If the Transport requests gzip on its own and gets a gzipped response, it's transparently // decoded in the Response.Body. // However, if the user explicitly requested gzip it is not automatically uncompressed. - DisableCompression bool + disableCompression bool - Logger *slog.Logger + logger *slog.Logger - initOnce sync.Once - hconn *connection requestWriter *requestWriter decoder *qpack.Decoder } -var _ http.RoundTripper = &SingleDestinationRoundTripper{} +var _ http.RoundTripper = &ClientConn{} -func (c *SingleDestinationRoundTripper) Start() Connection { - c.initOnce.Do(func() { c.init() }) - return c.hconn -} +// Deprecated: SingleDestinationRoundTripper was renamed to ClientConn. +// It can be obtained by calling NewClientConn on a Transport. +type SingleDestinationRoundTripper = ClientConn -func (c *SingleDestinationRoundTripper) init() { +func newClientConn( + conn quic.Connection, + enableDatagrams bool, + additionalSettings map[uint64]uint64, + streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error), + uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool), + maxResponseHeaderBytes int64, + disableCompression bool, + logger *slog.Logger, +) *ClientConn { + c := &ClientConn{ + enableDatagrams: enableDatagrams, + additionalSettings: additionalSettings, + disableCompression: disableCompression, + logger: logger, + } + if maxResponseHeaderBytes <= 0 { + c.maxResponseHeaderBytes = defaultMaxResponseHeaderBytes + } else { + c.maxResponseHeaderBytes = uint64(maxResponseHeaderBytes) + } c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {}) c.requestWriter = newRequestWriter() - c.hconn = newConnection( - c.Connection.Context(), - c.Connection, - c.EnableDatagrams, + c.connection = *newConnection( + conn.Context(), + conn, + c.enableDatagrams, protocol.PerspectiveClient, - c.Logger, + c.logger, 0, ) // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupConn(c.hconn); err != nil { - if c.Logger != nil { - c.Logger.Debug("Setting up connection failed", "error", err) + if err := c.setupConn(); err != nil { + if c.logger != nil { + c.logger.Debug("Setting up connection failed", "error", err) } - c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") + c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } }() - if c.StreamHijacker != nil { - go c.handleBidirectionalStreams() + if streamHijacker != nil { + go c.handleBidirectionalStreams(streamHijacker) } - go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker) + go c.connection.handleUnidirectionalStreams(uniStreamHijacker) + return c } -func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error { +// OpenRequestStream opens a new request stream on the HTTP/3 connection. +func (c *ClientConn) OpenRequestStream(ctx context.Context) (RequestStream, error) { + return c.connection.openRequestStream(ctx, c.requestWriter, nil, c.disableCompression, c.maxResponseHeaderBytes) +} + +func (c *ClientConn) setupConn() error { // open the control stream - str, err := conn.OpenUniStream() + str, err := c.connection.OpenUniStream() if err != nil { return err } b := make([]byte, 0, 64) b = quicvarint.Append(b, streamTypeControlStream) // send the SETTINGS frame - b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b) + b = (&settingsFrame{Datagram: c.enableDatagrams, Other: c.additionalSettings}).Append(b) _, err = str.Write(b) return err } -func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { +func (c *ClientConn) handleBidirectionalStreams(streamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)) { for { - str, err := c.hconn.AcceptStream(context.Background()) + str, err := c.connection.AcceptStream(context.Background()) if err != nil { - if c.Logger != nil { - c.Logger.Debug("accepting bidirectional stream failed", "error", err) + if c.logger != nil { + c.logger.Debug("accepting bidirectional stream failed", "error", err) } return } fp := &frameParser{ r: str, - conn: c.hconn, + conn: &c.connection, unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) { - id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) - return c.StreamHijacker(ft, id, str, e) + id := c.connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) + return streamHijacker(ft, id, str, e) }, } go func() { @@ -141,26 +160,17 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { return } if err != nil { - if c.Logger != nil { - c.Logger.Debug("error handling stream", "error", err) + if c.logger != nil { + c.logger.Debug("error handling stream", "error", err) } } - c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + c.connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") }() } } -func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 { - if c.MaxResponseHeaderBytes <= 0 { - return defaultMaxResponseHeaderBytes - } - return uint64(c.MaxResponseHeaderBytes) -} - // RoundTrip executes a request and returns a response -func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - c.initOnce.Do(func() { c.init() }) - +func (c *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { rsp, err := c.roundTrip(req) if err != nil && req.Context().Err() != nil { // if the context was canceled, return the context cancellation error @@ -169,7 +179,7 @@ func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Resp return rsp, err } -func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) { +func (c *ClientConn) roundTrip(req *http.Request) (*http.Response, error) { // Immediately send out this request, if this is a 0-RTT request. switch req.Method { case MethodGet0RTT: @@ -200,17 +210,23 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp connCtx := c.Connection.Context() // wait for the server's SETTINGS frame to arrive select { - case <-c.hconn.ReceivedSettings(): + case <-c.connection.ReceivedSettings(): case <-connCtx.Done(): return nil, context.Cause(connCtx) } - if !c.hconn.Settings().EnableExtendedConnect { + if !c.connection.Settings().EnableExtendedConnect { return nil, errors.New("http3: server didn't enable Extended CONNECT") } } reqDone := make(chan struct{}) - str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes()) + str, err := c.connection.openRequestStream( + req.Context(), + c.requestWriter, + reqDone, + c.disableCompression, + c.maxResponseHeaderBytes, + ) if err != nil { return nil, err } @@ -238,12 +254,6 @@ func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Resp return rsp, maybeReplaceError(err) } -func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) { - c.initOnce.Do(func() { c.init() }) - - return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes()) -} - // cancelingReader reads from the io.Reader. // It cancels writing on the stream if any error other than io.EOF occurs. type cancelingReader struct { @@ -259,7 +269,7 @@ func (r *cancelingReader) Read(b []byte) (int, error) { return n, err } -func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { +func (c *ClientConn) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { defer body.Close() buf := make([]byte, bodyCopyBufferSize) sr := &cancelingReader{str: str, r: body} @@ -283,7 +293,7 @@ func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.Read return err } -func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) { +func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Response, error) { if err := str.SendRequestHeader(req); err != nil { return nil, err } @@ -299,8 +309,8 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques contentLength = req.ContentLength } if err := c.sendRequestBody(str, req.Body, contentLength); err != nil { - if c.Logger != nil { - c.Logger.Debug("error writing request", "error", err) + if c.logger != nil { + c.logger.Debug("error writing request", "error", err) } } str.Close() @@ -337,7 +347,7 @@ func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *reques } break } - connState := c.hconn.ConnectionState().TLS + connState := c.connection.ConnectionState().TLS res.TLS = &connState res.Request = req return res, nil diff --git a/vendor/github.com/quic-go/quic-go/http3/conn.go b/vendor/github.com/quic-go/quic-go/http3/conn.go index 0f372b0d..ec62ed3f 100644 --- a/vendor/github.com/quic-go/quic-go/http3/conn.go +++ b/vendor/github.com/quic-go/quic-go/http3/conn.go @@ -170,7 +170,7 @@ func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string) return c.Connection.CloseWithError(code, msg) } -func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) { +func (c *connection) handleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) { var ( rcvdControlStr atomic.Bool rcvdQPACKEncoderStr atomic.Bool @@ -316,10 +316,12 @@ func (c *connection) receiveDatagrams() error { } // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received. +// Settings can be optained from the Settings method after the channel was closed. func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings } // Settings returns the settings received on this connection. // It is only valid to call this function after the channel returned by ReceivedSettings was closed. func (c *connection) Settings() *Settings { return c.settings } +// Context returns the context of the underlying QUIC connection. func (c *connection) Context() context.Context { return c.ctx } diff --git a/vendor/github.com/quic-go/quic-go/http3/frames.go b/vendor/github.com/quic-go/quic-go/http3/frames.go index 66cba68c..b54afb31 100644 --- a/vendor/github.com/quic-go/quic-go/http3/frames.go +++ b/vendor/github.com/quic-go/quic-go/http3/frames.go @@ -66,7 +66,8 @@ func (p *frameParser) ParseNext() (frame, error) { return parseSettingsFrame(p.r, l) case 0x3: // CANCEL_PUSH case 0x5: // PUSH_PROMISE - case 0x7: // GOAWAY + case 0x7: + return parseGoAwayFrame(qr, l) case 0xd: // MAX_PUSH_ID case 0x2, 0x6, 0x8, 0x9: p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") @@ -194,3 +195,27 @@ func (f *settingsFrame) Append(b []byte) []byte { } return b } + +type goAwayFrame struct { + StreamID quic.StreamID +} + +func parseGoAwayFrame(r io.ByteReader, l uint64) (*goAwayFrame, error) { + frame := &goAwayFrame{} + cbr := countingByteReader{ByteReader: r} + id, err := quicvarint.Read(&cbr) + if err != nil { + return nil, err + } + if cbr.Read != int(l) { + return nil, errors.New("GOAWAY frame: inconsistent length") + } + frame.StreamID = quic.StreamID(id) + return frame, nil +} + +func (f *goAwayFrame) Append(b []byte) []byte { + b = quicvarint.Append(b, 0x7) + b = quicvarint.Append(b, uint64(quicvarint.Len(uint64(f.StreamID)))) + return quicvarint.Append(b, uint64(f.StreamID)) +} diff --git a/vendor/github.com/quic-go/quic-go/http3/server.go b/vendor/github.com/quic-go/quic-go/http3/server.go index 9f285b6e..097a8005 100644 --- a/vendor/github.com/quic-go/quic-go/http3/server.go +++ b/vendor/github.com/quic-go/quic-go/http3/server.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -45,6 +46,8 @@ const ( streamTypeQPACKDecoderStream = 3 ) +const goawayTimeout = 5 * time.Second + // A QUICEarlyListener listens for incoming QUIC connections. type QUICEarlyListener interface { Accept(context.Context) (quic.EarlyConnection, error) @@ -213,7 +216,13 @@ type Server struct { mutex sync.RWMutex listeners map[*QUICEarlyListener]listenerInfo - closed bool + closed bool + closeCtx context.Context // canceled when the server is closed + closeCancel context.CancelFunc // cancels the closeCtx + graceCtx context.Context // canceled when the server is closed or gracefully closed + graceCancel context.CancelFunc // cancels the graceCtx + connCount atomic.Int64 + connHandlingDone chan struct{} altSvcHeader string } @@ -265,8 +274,31 @@ func (s *Server) Serve(conn net.PacketConn) error { return s.serveListener(ln) } +// init initializes the contexts used for shutting down the server. +// It must be called with the mutex held. +func (s *Server) init() { + if s.closeCtx == nil { + s.closeCtx, s.closeCancel = context.WithCancel(context.Background()) + s.graceCtx, s.graceCancel = context.WithCancel(s.closeCtx) + } + s.connHandlingDone = make(chan struct{}, 1) +} + +func (s *Server) decreaseConnCount() { + if s.connCount.Add(-1) == 0 && s.graceCtx.Err() != nil { + close(s.connHandlingDone) + } +} + // ServeQUICConn serves a single QUIC connection. func (s *Server) ServeQUICConn(conn quic.Connection) error { + s.mutex.Lock() + s.init() + s.mutex.Unlock() + + s.connCount.Add(1) + defer s.decreaseConnCount() + return s.handleConn(conn) } @@ -289,14 +321,17 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error { func (s *Server) serveListener(ln QUICEarlyListener) error { for { - conn, err := ln.Accept(context.Background()) - if err == quic.ErrServerClosed { + conn, err := ln.Accept(s.graceCtx) + // server closed + if errors.Is(err, quic.ErrServerClosed) || s.graceCtx.Err() != nil { return http.ErrServerClosed } if err != nil { return err } + s.connCount.Add(1) go func() { + defer s.decreaseConnCount() if err := s.handleConn(conn); err != nil { if s.Logger != nil { s.Logger.Debug("handling connection failed", "error", err) @@ -430,6 +465,7 @@ func (s *Server) addListener(l *QUICEarlyListener) error { if s.listeners == nil { s.listeners = make(map[*QUICEarlyListener]listenerInfo) } + s.init() laddr := (*l).Addr() if port, err := extractPort(laddr.String()); err == nil { @@ -453,9 +489,12 @@ func (s *Server) removeListener(l *QUICEarlyListener) { s.generateAltSvcHeader() } +// handleConn handles the HTTP/3 exchange on a QUIC connection. +// It blocks until all HTTP handlers for all streams have returned. func (s *Server) handleConn(conn quic.Connection) error { - // send a SETTINGS frame - str, err := conn.OpenUniStream() + // open the control stream and send a SETTINGS frame, it's also used to send a GOAWAY frame later + // when the server is gracefully closed + ctrlStr, err := conn.OpenUniStream() if err != nil { return fmt.Errorf("opening the control stream failed: %w", err) } @@ -466,7 +505,7 @@ func (s *Server) handleConn(conn quic.Connection) error { ExtendedConnect: true, Other: s.AdditionalSettings, }).Append(b) - str.Write(b) + ctrlStr.Write(b) ctx := conn.Context() ctx = context.WithValue(ctx, ServerContextKey, s) @@ -487,21 +526,60 @@ func (s *Server) handleConn(conn quic.Connection) error { s.Logger, s.IdleTimeout, ) - go hconn.HandleUnidirectionalStreams(s.UniStreamHijacker) + go hconn.handleUnidirectionalStreams(s.UniStreamHijacker) + var nextStreamID quic.StreamID + var wg sync.WaitGroup + var handleErr error // Process all requests immediately. // It's the client's responsibility to decide which requests are eligible for 0-RTT. for { - str, datagrams, err := hconn.acceptStream(context.Background()) + str, datagrams, err := hconn.acceptStream(s.graceCtx) if err != nil { - var appErr *quic.ApplicationError - if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) { - return nil + // server (not gracefully) closed, close the connection immediately + if s.closeCtx.Err() != nil { + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") + handleErr = http.ErrServerClosed + break } - return fmt.Errorf("accepting stream failed: %w", err) + + // gracefully closed, send GOAWAY frame and wait for requests to complete or grace period to end + // new requests will be rejected and shouldn't be sent + if s.graceCtx.Err() != nil { + b = (&goAwayFrame{StreamID: nextStreamID}).Append(b[:0]) + // set a deadline to send the GOAWAY frame + ctrlStr.SetWriteDeadline(time.Now().Add(goawayTimeout)) + ctrlStr.Write(b) + + select { + case <-hconn.Context().Done(): + // we expect the client to eventually close the connection after receiving the GOAWAY + case <-s.closeCtx.Done(): + // close the connection after graceful period + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") + } + handleErr = http.ErrServerClosed + break + } + + var appErr *quic.ApplicationError + if !errors.As(err, &appErr) || appErr.ErrorCode != quic.ApplicationErrorCode(ErrCodeNoError) { + handleErr = fmt.Errorf("accepting stream failed: %w", err) + } + break } - go s.handleRequest(hconn, str, datagrams, hconn.decoder) + + nextStreamID = str.StreamID() + 4 + wg.Add(1) + go func() { + // handleRequest will return once the request has been handled, + // or the underlying connection is closed + defer wg.Done() + s.handleRequest(hconn, str, datagrams, hconn.decoder) + }() } + wg.Wait() + return handleErr } func (s *Server) maxHeaderBytes() uint64 { @@ -606,7 +684,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat if logger == nil { logger = slog.Default() } - logger.Error("http: panic serving", "arg", p, "trace", buf) + logger.Error("http: panic serving", "arg", p, "trace", string(buf)) } }() handler.ServeHTTP(r, req) @@ -643,11 +721,17 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. +// It is the caller's responsibility to close any connection passed to ServeQUICConn. func (s *Server) Close() error { s.mutex.Lock() defer s.mutex.Unlock() s.closed = true + // server is never used + if s.closeCtx == nil { + return nil + } + s.closeCancel() var err error for ln := range s.listeners { @@ -655,14 +739,44 @@ func (s *Server) Close() error { err = cerr } } + if s.connCount.Load() == 0 { + return err + } + // wait for all connections to be closed + <-s.connHandlingDone return err } -// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. -// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. -func (s *Server) CloseGracefully(timeout time.Duration) error { - // TODO: implement - return nil +// Shutdown shuts down the server gracefully. +// The server sends a GOAWAY frame first, then or for all running requests to complete. +// Shutdown in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. +func (s *Server) Shutdown(ctx context.Context) error { + s.mutex.Lock() + s.closed = true + // server is never used + if s.closeCtx == nil { + s.mutex.Unlock() + return nil + } + s.graceCancel() + s.mutex.Unlock() + + if s.connCount.Load() == 0 { + return s.Close() + } + select { + case <-s.connHandlingDone: // all connections were closed + // When receiving a GOAWAY frame, HTTP/3 clients are expected to close the connection + // once all requests were successfully handled... + return s.Close() + case <-ctx.Done(): + // ... however, clients handling long-lived requests (and misbehaving clients), + // might not do so before the context is cancelled. + // In this case, we close the server, which closes all existing connections + // (expect those passed to ServeQUICConn). + _ = s.Close() + return ctx.Err() + } } // ErrNoAltSvcPort is the error returned by SetQUICHeaders when no port was found @@ -690,11 +804,6 @@ func (s *Server) SetQUICHeaders(hdr http.Header) error { return nil } -// Deprecated: use SetQUICHeaders instead. -func (s *Server) SetQuicHeaders(hdr http.Header) error { - return s.SetQUICHeaders(hdr) -} - // ListenAndServeQUIC listens on the UDP network address addr and calls the // handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is // used when handler is nil. @@ -706,11 +815,6 @@ func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) er return server.ListenAndServeTLS(certFile, keyFile) } -// Deprecated: use ListenAndServeTLS instead. -func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error { - return ListenAndServeTLS(addr, certFile, keyFile, handler) -} - // ListenAndServeTLS listens on the given network address for both TLS/TCP and QUIC // connections in parallel. It returns if one of the two returns an error. // http.DefaultServeMux is used when handler is nil. diff --git a/vendor/github.com/quic-go/quic-go/http3/roundtrip.go b/vendor/github.com/quic-go/quic-go/http3/transport.go similarity index 63% rename from vendor/github.com/quic-go/quic-go/http3/roundtrip.go rename to vendor/github.com/quic-go/quic-go/http3/transport.go index a9b169ee..8dcaef4d 100644 --- a/vendor/github.com/quic-go/quic-go/http3/roundtrip.go +++ b/vendor/github.com/quic-go/quic-go/http3/transport.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/http" "strings" @@ -30,7 +31,7 @@ type Settings struct { // RoundTripOpt are options for the Transport.RoundTripOpt method. type RoundTripOpt struct { - // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. + // OnlyCachedConn controls whether the Transport may create a new QUIC connection. // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. OnlyCachedConn bool } @@ -59,10 +60,8 @@ func (r *roundTripperWithCount) Close() error { return nil } -// RoundTripper implements the http.RoundTripper interface -type RoundTripper struct { - mutex sync.Mutex - +// Transport implements the http.RoundTripper interface +type Transport struct { // TLSClientConfig specifies the TLS configuration to use with // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config @@ -97,6 +96,13 @@ type RoundTripper struct { // However, if the user explicitly requested gzip it is not automatically uncompressed. DisableCompression bool + StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) + UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) + + Logger *slog.Logger + + mutex sync.Mutex + initOnce sync.Once initErr error @@ -107,18 +113,56 @@ type RoundTripper struct { } var ( - _ http.RoundTripper = &RoundTripper{} - _ io.Closer = &RoundTripper{} + _ http.RoundTripper = &Transport{} + _ io.Closer = &Transport{} ) -// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set +// Deprecated: RoundTripper was renamed to Transport. +type RoundTripper = Transport + +// ErrNoCachedConn is returned when Transport.OnlyCachedConn is set var ErrNoCachedConn = errors.New("http3: no cached connection was available") +func (t *Transport) init() error { + if t.newClient == nil { + t.newClient = func(conn quic.EarlyConnection) singleRoundTripper { + return newClientConn( + conn, + t.EnableDatagrams, + t.AdditionalSettings, + t.StreamHijacker, + t.UniStreamHijacker, + t.MaxResponseHeaderBytes, + t.DisableCompression, + t.Logger, + ) + } + } + if t.QUICConfig == nil { + t.QUICConfig = defaultQuicConfig.Clone() + t.QUICConfig.EnableDatagrams = t.EnableDatagrams + } + if t.EnableDatagrams && !t.QUICConfig.EnableDatagrams { + return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") + } + if len(t.QUICConfig.Versions) == 0 { + t.QUICConfig = t.QUICConfig.Clone() + t.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]} + } + if len(t.QUICConfig.Versions) != 1 { + return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") + } + if t.QUICConfig.MaxIncomingStreams == 0 { + t.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams + } + return nil +} + // RoundTripOpt is like RoundTrip, but takes options. -func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { - r.initOnce.Do(func() { r.initErr = r.init() }) - if r.initErr != nil { - return nil, r.initErr +func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + t.initOnce.Do(func() { t.initErr = t.init() }) + if t.initErr != nil { + return nil, t.initErr } if req.URL == nil { @@ -154,7 +198,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } hostname := authorityAddr(hostnameFromURL(req.URL)) - cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn) + cl, isReused, err := t.getClient(req.Context(), hostname, opt.OnlyCachedConn) if err != nil { return nil, err } @@ -166,7 +210,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } if cl.dialErr != nil { - r.removeClient(hostname) + t.removeClient(hostname) return nil, cl.dialErr } defer cl.useCount.Add(-1) @@ -176,12 +220,12 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. // so we remove the client from the cache so that subsequent trips reconnect // context cancelation is excluded as is does not signify a connection error if !errors.Is(err, context.Canceled) { - r.removeClient(hostname) + t.removeClient(hostname) } if isReused { if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - return r.RoundTripOpt(req, opt) + return t.RoundTripOpt(req, opt) } } } @@ -189,51 +233,19 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } // RoundTrip does a round trip. -func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return r.RoundTripOpt(req, RoundTripOpt{}) +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripOpt(req, RoundTripOpt{}) } -func (r *RoundTripper) init() error { - if r.newClient == nil { - r.newClient = func(conn quic.EarlyConnection) singleRoundTripper { - return &SingleDestinationRoundTripper{ - Connection: conn, - EnableDatagrams: r.EnableDatagrams, - DisableCompression: r.DisableCompression, - AdditionalSettings: r.AdditionalSettings, - MaxResponseHeaderBytes: r.MaxResponseHeaderBytes, - } - } - } - if r.QUICConfig == nil { - r.QUICConfig = defaultQuicConfig.Clone() - r.QUICConfig.EnableDatagrams = r.EnableDatagrams - } - if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams { - return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") - } - if len(r.QUICConfig.Versions) == 0 { - r.QUICConfig = r.QUICConfig.Clone() - r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]} - } - if len(r.QUICConfig.Versions) != 1 { - return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") - } - if r.QUICConfig.MaxIncomingStreams == 0 { - r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams - } - return nil -} +func (t *Transport) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) { + t.mutex.Lock() + defer t.mutex.Unlock() -func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) { - r.mutex.Lock() - defer r.mutex.Unlock() - - if r.clients == nil { - r.clients = make(map[string]*roundTripperWithCount) + if t.clients == nil { + t.clients = make(map[string]*roundTripperWithCount) } - cl, ok := r.clients[hostname] + cl, ok := t.clients[hostname] if !ok { if onlyCached { return nil, false, ErrNoCachedConn @@ -246,7 +258,7 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache go func() { defer close(cl.dialing) defer cancel() - conn, rt, err := r.dial(ctx, hostname) + conn, rt, err := t.dial(ctx, hostname) if err != nil { cl.dialErr = err return @@ -254,12 +266,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache cl.conn = conn cl.rt = rt }() - r.clients[hostname] = cl + t.clients[hostname] = cl } select { case <-cl.dialing: if cl.dialErr != nil { - delete(r.clients, hostname) + delete(t.clients, hostname) return nil, false, cl.dialErr } select { @@ -273,12 +285,12 @@ func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCache return cl, isReused, nil } -func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) { +func (t *Transport) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) { var tlsConf *tls.Config - if r.TLSClientConfig == nil { + if t.TLSClientConfig == nil { tlsConf = &tls.Config{} } else { - tlsConf = r.TLSClientConfig.Clone() + tlsConf = t.TLSClientConfig.Clone() } if tlsConf.ServerName == "" { sni, _, err := net.SplitHostPort(hostname) @@ -289,61 +301,79 @@ func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyCon tlsConf.ServerName = sni } // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])} + tlsConf.NextProtos = []string{versionToALPN(t.QUICConfig.Versions[0])} - dial := r.Dial + dial := t.Dial if dial == nil { - if r.transport == nil { + if t.transport == nil { udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, nil, err } - r.transport = &quic.Transport{Conn: udpConn} + t.transport = &quic.Transport{Conn: udpConn} } dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) + return t.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) } } - conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig) + conn, err := dial(ctx, hostname, tlsConf, t.QUICConfig) if err != nil { return nil, nil, err } - return conn, r.newClient(conn), nil + return conn, t.newClient(conn), nil } -func (r *RoundTripper) removeClient(hostname string) { - r.mutex.Lock() - defer r.mutex.Unlock() - if r.clients == nil { +func (t *Transport) removeClient(hostname string) { + t.mutex.Lock() + defer t.mutex.Unlock() + if t.clients == nil { return } - delete(r.clients, hostname) + delete(t.clients, hostname) } -// Close closes the QUIC connections that this RoundTripper has used. -// It also closes the underlying UDPConn if it is not nil. -func (r *RoundTripper) Close() error { - r.mutex.Lock() - defer r.mutex.Unlock() - for _, cl := range r.clients { +// NewClientConn creates a new HTTP/3 client connection on top of a QUIC connection. +// Most users should use RoundTrip instead of creating a connection directly. +// Specifically, it is not needed to perform GET, POST, HEAD and CONNECT requests. +// +// Obtaining a ClientConn is only needed for more advanced use cases, such as +// using Extended CONNECT for WebTransport or the various MASQUE protocols. +func (t *Transport) NewClientConn(conn quic.Connection) *ClientConn { + return newClientConn( + conn, + t.EnableDatagrams, + t.AdditionalSettings, + t.StreamHijacker, + t.UniStreamHijacker, + t.MaxResponseHeaderBytes, + t.DisableCompression, + t.Logger, + ) +} + +// Close closes the QUIC connections that this Transport has used. +func (t *Transport) Close() error { + t.mutex.Lock() + defer t.mutex.Unlock() + for _, cl := range t.clients { if err := cl.Close(); err != nil { return err } } - r.clients = nil - if r.transport != nil { - if err := r.transport.Close(); err != nil { + t.clients = nil + if t.transport != nil { + if err := t.transport.Close(); err != nil { return err } - if err := r.transport.Conn.Close(); err != nil { + if err := t.transport.Conn.Close(); err != nil { return err } - r.transport = nil + t.transport = nil } return nil } @@ -376,13 +406,17 @@ func isNotToken(r rune) bool { return !httpguts.IsTokenRune(r) } -func (r *RoundTripper) CloseIdleConnections() { - r.mutex.Lock() - defer r.mutex.Unlock() - for hostname, cl := range r.clients { +// CloseIdleConnections closes any QUIC connections in the transport's pool that are currently idle. +// An idle connection is one that was previously used for requests but is now sitting unused. +// This method does not interrupt any connections currently in use. +// It also does not affect connections obtained via NewClientConn. +func (t *Transport) CloseIdleConnections() { + t.mutex.Lock() + defer t.mutex.Unlock() + for hostname, cl := range t.clients { if cl.useCount.Load() == 0 { cl.Close() - delete(r.clients, hostname) + delete(t.clients, hostname) } } } diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/interface.go b/vendor/github.com/quic-go/quic-go/internal/handshake/interface.go index 6e0eec77..c3a59fcd 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/interface.go @@ -83,6 +83,29 @@ const ( EventHandshakeComplete ) +func (k EventKind) String() string { + switch k { + case EventNoEvent: + return "EventNoEvent" + case EventWriteInitialData: + return "EventWriteInitialData" + case EventWriteHandshakeData: + return "EventWriteHandshakeData" + case EventReceivedReadKeys: + return "EventReceivedReadKeys" + case EventDiscard0RTTKeys: + return "EventDiscard0RTTKeys" + case EventReceivedTransportParameters: + return "EventReceivedTransportParameters" + case EventRestoredTransportParameters: + return "EventRestoredTransportParameters" + case EventHandshakeComplete: + return "EventHandshakeComplete" + default: + return "Unknown EventKind" + } +} + // Event is a handshake event. type Event struct { Kind EventKind diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go index 24e1f64c..dcfac67d 100644 --- a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go @@ -27,11 +27,6 @@ type RTTStats struct { maxAckDelay time.Duration } -// NewRTTStats makes a properly initialized RTTStats object -func NewRTTStats() *RTTStats { - return &RTTStats{} -} - // MinRTT Returns the minRTT for the entire connection. // May return Zero if no valid updates have occurred. func (r *RTTStats) MinRTT() time.Duration { return r.minRTT } @@ -113,19 +108,3 @@ func (r *RTTStats) SetInitialRTT(t time.Duration) { r.smoothedRTT = t r.latestRTT = t } - -// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset. -func (r *RTTStats) OnConnectionMigration() { - r.latestRTT = 0 - r.minRTT = 0 - r.smoothedRTT = 0 - r.meanDeviation = 0 -} - -// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt -// is larger. The mean deviation is increased to the most recent deviation if -// it's larger. -func (r *RTTStats) ExpireSmoothedMetrics() { - r.meanDeviation = max(r.meanDeviation, (r.smoothedRTT - r.latestRTT).Abs()) - r.smoothedRTT = max(r.smoothedRTT, r.latestRTT) -} diff --git a/vendor/github.com/quic-go/quic-go/receive_stream.go b/vendor/github.com/quic-go/quic-go/receive_stream.go index 80340923..b8535ef5 100644 --- a/vendor/github.com/quic-go/quic-go/receive_stream.go +++ b/vendor/github.com/quic-go/quic-go/receive_stream.go @@ -253,6 +253,9 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) (queuedNe if s.cancelledLocally { // duplicate call to CancelRead return false } + if s.closeForShutdownErr != nil { + return false + } s.cancelledLocally = true if s.errorRead || s.cancelledRemotely { return false diff --git a/vendor/github.com/quic-go/quic-go/send_stream.go b/vendor/github.com/quic-go/quic-go/send_stream.go index bcaf2abf..699c40ef 100644 --- a/vendor/github.com/quic-go/quic-go/send_stream.go +++ b/vendor/github.com/quic-go/quic-go/send_stream.go @@ -423,6 +423,10 @@ func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() + if s.closeForShutdownErr != nil { + s.mutex.Unlock() + return + } if !remote { s.cancellationFlagged = true if s.cancelWriteErr != nil { diff --git a/vendor/modules.txt b/vendor/modules.txt index b9f6c089..4afc2ac0 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -104,7 +104,7 @@ github.com/powerman/deepequal # github.com/quic-go/qpack v0.5.1 ## explicit; go 1.22 github.com/quic-go/qpack -# github.com/quic-go/quic-go v0.47.0 +# github.com/quic-go/quic-go v0.48.1 ## explicit; go 1.22 github.com/quic-go/quic-go github.com/quic-go/quic-go/http3