http3: fix race condition when accessing the client's connection (#3696)

* http3: fix race condition when accessing the client's connection

* add an integration test for concurrent HTTP requests

---------

Co-authored-by: Bulat Khasanov <afti@yandex.ru>
This commit is contained in:
Marten Seemann 2023-02-14 11:54:09 +13:00 committed by GitHub
parent aa091fe672
commit e0d4ffffef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 35 deletions

View file

@ -9,6 +9,7 @@ import (
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
@ -63,7 +64,7 @@ type client struct {
decoder *qpack.Decoder
hostname string
conn quic.EarlyConnection
conn atomic.Pointer[quic.EarlyConnection]
logger utils.Logger
}
@ -108,33 +109,35 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con
func (c *client) dial(ctx context.Context) error {
var err error
var conn quic.EarlyConnection
if c.dialer != nil {
c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
} else {
c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
c.conn.Store(&conn)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(); err != nil {
if err := c.setupConn(conn); err != nil {
c.logger.Debugf("Setting up connection failed: %s", err)
c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
}
}()
if c.opts.StreamHijacker != nil {
go c.handleBidirectionalStreams()
go c.handleBidirectionalStreams(conn)
}
go c.handleUnidirectionalStreams()
go c.handleUnidirectionalStreams(conn)
return nil
}
func (c *client) setupConn() error {
func (c *client) setupConn(conn quic.EarlyConnection) error {
// open the control stream
str, err := c.conn.OpenUniStream()
str, err := conn.OpenUniStream()
if err != nil {
return err
}
@ -146,16 +149,16 @@ func (c *client) setupConn() error {
return err
}
func (c *client) handleBidirectionalStreams() {
func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
for {
str, err := c.conn.AcceptStream(context.Background())
str, err := conn.AcceptStream(context.Background())
if err != nil {
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
return
}
go func(str quic.Stream) {
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str, e)
return c.opts.StreamHijacker(ft, conn, str, e)
})
if err == errHijacked {
return
@ -163,14 +166,14 @@ func (c *client) handleBidirectionalStreams() {
if err != nil {
c.logger.Debugf("error handling stream: %s", err)
}
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}(str)
}
}
func (c *client) handleUnidirectionalStreams() {
func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
for {
str, err := c.conn.AcceptUniStream(context.Background())
str, err := conn.AcceptUniStream(context.Background())
if err != nil {
c.logger.Debugf("accepting unidirectional stream failed: %s", err)
return
@ -179,7 +182,7 @@ func (c *client) handleUnidirectionalStreams() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
@ -194,10 +197,10 @@ func (c *client) handleUnidirectionalStreams() {
return
case streamTypePushStream:
// We never increased the Push ID, so we don't expect any push streams.
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
@ -205,12 +208,12 @@ func (c *client) handleUnidirectionalStreams() {
}
f, err := parseNextFrame(str, nil)
if err != nil {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
return
}
sf, ok := f.(*settingsFrame)
if !ok {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
return
}
if !sf.Datagram {
@ -219,18 +222,19 @@ func (c *client) handleUnidirectionalStreams() {
// If datagram support was enabled on our side as well as on the server side,
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
}
}(str)
}
}
func (c *client) Close() error {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
return nil
}
return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
return (*conn).CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
}
func (c *client) maxHeaderBytes() uint64 {
@ -249,24 +253,26 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
c.dialOnce.Do(func() {
c.handshakeErr = c.dial(req.Context())
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
// At this point, c.conn is guaranteed to be set.
conn := *c.conn.Load()
// Immediately send out this request, if this is a 0-RTT request.
if req.Method == MethodGet0RTT {
req.Method = http.MethodGet
} else {
// wait for the handshake to complete
select {
case <-c.conn.HandshakeComplete().Done():
case <-conn.HandshakeComplete().Done():
case <-req.Context().Done():
return nil, req.Context().Err()
}
}
str, err := c.conn.OpenStreamSync(req.Context())
str, err := conn.OpenStreamSync(req.Context())
if err != nil {
return nil, err
}
@ -290,7 +296,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
if opt.DontCloseRequestStream {
doneChan = nil
}
rsp, rerr := c.doRequest(req, str, opt, doneChan)
rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
if rerr.err != nil { // if any error occurred
close(reqDone)
<-done
@ -302,7 +308,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
if rerr.err != nil {
reason = rerr.err.Error()
}
c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return nil, rerr.err
}
@ -340,7 +346,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
return nil
}
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
var requestGzip bool
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
requestGzip = true
@ -353,7 +359,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
str.Close()
}
hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
if req.Body != nil {
// send the request body asynchronously
go func() {
@ -387,7 +393,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
return nil, newConnError(errorGeneralProtocolError, err)
}
connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS)
connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS)
res := &http.Response{
Proto: "HTTP/3.0",
ProtoMajor: 3,
@ -408,7 +414,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
res.Header.Add(hf.Name, hf.Value)
}
}
respBody := newResponseBody(hstr, c.conn, reqDone)
respBody := newResponseBody(hstr, conn, reqDone)
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
@ -438,11 +444,12 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
}
func (c *client) HandshakeComplete() bool {
if c.conn == nil {
conn := c.conn.Load()
if conn == nil {
return false
}
select {
case <-c.conn.HandshakeComplete().Done():
case <-(*conn).HandshakeComplete().Done():
return true
default:
return false

View file

@ -18,6 +18,7 @@ import (
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/testdata"
"golang.org/x/sync/errgroup"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@ -121,6 +122,27 @@ var _ = Describe("HTTP tests", func() {
Expect(string(body)).To(Equal("Hello, World!\n"))
})
It("downloads concurrently", func() {
group, ctx := errgroup.WithContext(context.Background())
for i := 0; i < 2; i++ {
group.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://localhost:"+port+"/hello", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := client.Do(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World!\n"))
return nil
})
}
err := group.Wait()
Expect(err).ToNot(HaveOccurred())
})
It("sets and gets request headers", func() {
handlerCalled := make(chan struct{})
mux.HandleFunc("/headers/request", func(w http.ResponseWriter, r *http.Request) {