diff --git a/http3/client.go b/http3/client.go index 35dd82e6..c63505e1 100644 --- a/http3/client.go +++ b/http3/client.go @@ -9,6 +9,7 @@ import ( "net/http" "strconv" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -63,7 +64,7 @@ type client struct { decoder *qpack.Decoder hostname string - conn quic.EarlyConnection + conn atomic.Pointer[quic.EarlyConnection] logger utils.Logger } @@ -108,33 +109,35 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con func (c *client) dial(ctx context.Context) error { var err error + var conn quic.EarlyConnection if c.dialer != nil { - c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) + conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) } else { - c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) + conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) } if err != nil { return err } + c.conn.Store(&conn) // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupConn(); err != nil { + if err := c.setupConn(conn); err != nil { c.logger.Debugf("Setting up connection failed: %s", err) - c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") } }() if c.opts.StreamHijacker != nil { - go c.handleBidirectionalStreams() + go c.handleBidirectionalStreams(conn) } - go c.handleUnidirectionalStreams() + go c.handleUnidirectionalStreams(conn) return nil } -func (c *client) setupConn() error { +func (c *client) setupConn(conn quic.EarlyConnection) error { // open the control stream - str, err := c.conn.OpenUniStream() + str, err := conn.OpenUniStream() if err != nil { return err } @@ -146,16 +149,16 @@ func (c *client) setupConn() error { return err } -func (c *client) handleBidirectionalStreams() { +func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { for { - str, err := c.conn.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) if err != nil { c.logger.Debugf("accepting bidirectional stream failed: %s", err) return } go func(str quic.Stream) { _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { - return c.opts.StreamHijacker(ft, c.conn, str, e) + return c.opts.StreamHijacker(ft, conn, str, e) }) if err == errHijacked { return @@ -163,14 +166,14 @@ func (c *client) handleBidirectionalStreams() { if err != nil { c.logger.Debugf("error handling stream: %s", err) } - c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") }(str) } } -func (c *client) handleUnidirectionalStreams() { +func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { for { - str, err := c.conn.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) if err != nil { c.logger.Debugf("accepting unidirectional stream failed: %s", err) return @@ -179,7 +182,7 @@ func (c *client) handleUnidirectionalStreams() { go func(str quic.ReceiveStream) { streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) { + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) { return } c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) @@ -194,10 +197,10 @@ func (c *client) handleUnidirectionalStreams() { return case streamTypePushStream: // We never increased the Push ID, so we don't expect any push streams. - c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") return default: - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) { + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { return } str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) @@ -205,12 +208,12 @@ func (c *client) handleUnidirectionalStreams() { } f, err := parseNextFrame(str, nil) if err != nil { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") return } if !sf.Datagram { @@ -219,18 +222,19 @@ func (c *client) handleUnidirectionalStreams() { // If datagram support was enabled on our side as well as on the server side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams { - c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { + conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") } }(str) } } func (c *client) Close() error { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return nil } - return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "") + return (*conn).CloseWithError(quic.ApplicationErrorCode(errorNoError), "") } func (c *client) maxHeaderBytes() uint64 { @@ -249,24 +253,26 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon c.dialOnce.Do(func() { c.handshakeErr = c.dial(req.Context()) }) - if c.handshakeErr != nil { return nil, c.handshakeErr } + // At this point, c.conn is guaranteed to be set. + conn := *c.conn.Load() + // Immediately send out this request, if this is a 0-RTT request. if req.Method == MethodGet0RTT { req.Method = http.MethodGet } else { // wait for the handshake to complete select { - case <-c.conn.HandshakeComplete().Done(): + case <-conn.HandshakeComplete().Done(): case <-req.Context().Done(): return nil, req.Context().Err() } } - str, err := c.conn.OpenStreamSync(req.Context()) + str, err := conn.OpenStreamSync(req.Context()) if err != nil { return nil, err } @@ -290,7 +296,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon if opt.DontCloseRequestStream { doneChan = nil } - rsp, rerr := c.doRequest(req, str, opt, doneChan) + rsp, rerr := c.doRequest(req, conn, str, opt, doneChan) if rerr.err != nil { // if any error occurred close(reqDone) <-done @@ -302,7 +308,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon if rerr.err != nil { reason = rerr.err.Error() } - c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } return nil, rerr.err } @@ -340,7 +346,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error { return nil } -func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { +func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true @@ -353,7 +359,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, str.Close() } - hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) if req.Body != nil { // send the request body asynchronously go func() { @@ -387,7 +393,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, return nil, newConnError(errorGeneralProtocolError, err) } - connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS) + connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS) res := &http.Response{ Proto: "HTTP/3.0", ProtoMajor: 3, @@ -408,7 +414,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(hstr, c.conn, reqDone) + respBody := newResponseBody(hstr, conn, reqDone) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. _, hasTransferEncoding := res.Header["Transfer-Encoding"] @@ -438,11 +444,12 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, } func (c *client) HandshakeComplete() bool { - if c.conn == nil { + conn := c.conn.Load() + if conn == nil { return false } select { - case <-c.conn.HandshakeComplete().Done(): + case <-(*conn).HandshakeComplete().Done(): return true default: return false diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index adb0d772..db67c47e 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -18,6 +18,7 @@ import ( "github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/testdata" + "golang.org/x/sync/errgroup" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -121,6 +122,27 @@ var _ = Describe("HTTP tests", func() { Expect(string(body)).To(Equal("Hello, World!\n")) }) + It("downloads concurrently", func() { + group, ctx := errgroup.WithContext(context.Background()) + for i := 0; i < 2; i++ { + group.Go(func() error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://localhost:"+port+"/hello", nil) + Expect(err).ToNot(HaveOccurred()) + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World!\n")) + + return nil + }) + } + + err := group.Wait() + Expect(err).ToNot(HaveOccurred()) + }) + It("sets and gets request headers", func() { handlerCalled := make(chan struct{}) mux.HandleFunc("/headers/request", func(w http.ResponseWriter, r *http.Request) {