mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
http3: fix race condition when accessing the client's connection (#3696)
* http3: fix race condition when accessing the client's connection * add an integration test for concurrent HTTP requests --------- Co-authored-by: Bulat Khasanov <afti@yandex.ru>
This commit is contained in:
parent
aa091fe672
commit
e0d4ffffef
2 changed files with 64 additions and 35 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
|
@ -63,7 +64,7 @@ type client struct {
|
|||
decoder *qpack.Decoder
|
||||
|
||||
hostname string
|
||||
conn quic.EarlyConnection
|
||||
conn atomic.Pointer[quic.EarlyConnection]
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -108,33 +109,35 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con
|
|||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
var err error
|
||||
var conn quic.EarlyConnection
|
||||
if c.dialer != nil {
|
||||
c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
|
||||
conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
|
||||
conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.conn.Store(&conn)
|
||||
|
||||
// send the SETTINGs frame, using 0-RTT data, if possible
|
||||
go func() {
|
||||
if err := c.setupConn(); err != nil {
|
||||
if err := c.setupConn(conn); err != nil {
|
||||
c.logger.Debugf("Setting up connection failed: %s", err)
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
|
||||
}
|
||||
}()
|
||||
|
||||
if c.opts.StreamHijacker != nil {
|
||||
go c.handleBidirectionalStreams()
|
||||
go c.handleBidirectionalStreams(conn)
|
||||
}
|
||||
go c.handleUnidirectionalStreams()
|
||||
go c.handleUnidirectionalStreams(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) setupConn() error {
|
||||
func (c *client) setupConn(conn quic.EarlyConnection) error {
|
||||
// open the control stream
|
||||
str, err := c.conn.OpenUniStream()
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -146,16 +149,16 @@ func (c *client) setupConn() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *client) handleBidirectionalStreams() {
|
||||
func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
|
||||
for {
|
||||
str, err := c.conn.AcceptStream(context.Background())
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
|
||||
return
|
||||
}
|
||||
go func(str quic.Stream) {
|
||||
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
|
||||
return c.opts.StreamHijacker(ft, c.conn, str, e)
|
||||
return c.opts.StreamHijacker(ft, conn, str, e)
|
||||
})
|
||||
if err == errHijacked {
|
||||
return
|
||||
|
@ -163,14 +166,14 @@ func (c *client) handleBidirectionalStreams() {
|
|||
if err != nil {
|
||||
c.logger.Debugf("error handling stream: %s", err)
|
||||
}
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleUnidirectionalStreams() {
|
||||
func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
|
||||
for {
|
||||
str, err := c.conn.AcceptUniStream(context.Background())
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
c.logger.Debugf("accepting unidirectional stream failed: %s", err)
|
||||
return
|
||||
|
@ -179,7 +182,7 @@ func (c *client) handleUnidirectionalStreams() {
|
|||
go func(str quic.ReceiveStream) {
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
|
||||
return
|
||||
}
|
||||
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||
|
@ -194,10 +197,10 @@ func (c *client) handleUnidirectionalStreams() {
|
|||
return
|
||||
case streamTypePushStream:
|
||||
// We never increased the Push ID, so we don't expect any push streams.
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
|
||||
return
|
||||
default:
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
|
||||
return
|
||||
}
|
||||
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
|
||||
|
@ -205,12 +208,12 @@ func (c *client) handleUnidirectionalStreams() {
|
|||
}
|
||||
f, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
|
||||
return
|
||||
}
|
||||
if !sf.Datagram {
|
||||
|
@ -219,18 +222,19 @@ func (c *client) handleUnidirectionalStreams() {
|
|||
// If datagram support was enabled on our side as well as on the server side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
|
||||
if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
|
||||
}
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
if c.conn == nil {
|
||||
conn := c.conn.Load()
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
|
||||
return (*conn).CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
|
||||
}
|
||||
|
||||
func (c *client) maxHeaderBytes() uint64 {
|
||||
|
@ -249,24 +253,26 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
|||
c.dialOnce.Do(func() {
|
||||
c.handshakeErr = c.dial(req.Context())
|
||||
})
|
||||
|
||||
if c.handshakeErr != nil {
|
||||
return nil, c.handshakeErr
|
||||
}
|
||||
|
||||
// At this point, c.conn is guaranteed to be set.
|
||||
conn := *c.conn.Load()
|
||||
|
||||
// Immediately send out this request, if this is a 0-RTT request.
|
||||
if req.Method == MethodGet0RTT {
|
||||
req.Method = http.MethodGet
|
||||
} else {
|
||||
// wait for the handshake to complete
|
||||
select {
|
||||
case <-c.conn.HandshakeComplete().Done():
|
||||
case <-conn.HandshakeComplete().Done():
|
||||
case <-req.Context().Done():
|
||||
return nil, req.Context().Err()
|
||||
}
|
||||
}
|
||||
|
||||
str, err := c.conn.OpenStreamSync(req.Context())
|
||||
str, err := conn.OpenStreamSync(req.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -290,7 +296,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
|||
if opt.DontCloseRequestStream {
|
||||
doneChan = nil
|
||||
}
|
||||
rsp, rerr := c.doRequest(req, str, opt, doneChan)
|
||||
rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
|
||||
if rerr.err != nil { // if any error occurred
|
||||
close(reqDone)
|
||||
<-done
|
||||
|
@ -302,7 +308,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
|
|||
if rerr.err != nil {
|
||||
reason = rerr.err.Error()
|
||||
}
|
||||
c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
}
|
||||
return nil, rerr.err
|
||||
}
|
||||
|
@ -340,7 +346,7 @@ func (c *client) sendRequestBody(str Stream, body io.ReadCloser) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
||||
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
||||
var requestGzip bool
|
||||
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||
requestGzip = true
|
||||
|
@ -353,7 +359,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
|
|||
str.Close()
|
||||
}
|
||||
|
||||
hstr := newStream(str, func() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
|
||||
hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") })
|
||||
if req.Body != nil {
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
|
@ -387,7 +393,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
|
|||
return nil, newConnError(errorGeneralProtocolError, err)
|
||||
}
|
||||
|
||||
connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS)
|
||||
connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS)
|
||||
res := &http.Response{
|
||||
Proto: "HTTP/3.0",
|
||||
ProtoMajor: 3,
|
||||
|
@ -408,7 +414,7 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
|
|||
res.Header.Add(hf.Name, hf.Value)
|
||||
}
|
||||
}
|
||||
respBody := newResponseBody(hstr, c.conn, reqDone)
|
||||
respBody := newResponseBody(hstr, conn, reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
|
||||
|
@ -438,11 +444,12 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
|
|||
}
|
||||
|
||||
func (c *client) HandshakeComplete() bool {
|
||||
if c.conn == nil {
|
||||
conn := c.conn.Load()
|
||||
if conn == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-c.conn.HandshakeComplete().Done():
|
||||
case <-(*conn).HandshakeComplete().Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/testdata"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
@ -121,6 +122,27 @@ var _ = Describe("HTTP tests", func() {
|
|||
Expect(string(body)).To(Equal("Hello, World!\n"))
|
||||
})
|
||||
|
||||
It("downloads concurrently", func() {
|
||||
group, ctx := errgroup.WithContext(context.Background())
|
||||
for i := 0; i < 2; i++ {
|
||||
group.Go(func() error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://localhost:"+port+"/hello", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp, err := client.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(body)).To(Equal("Hello, World!\n"))
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err := group.Wait()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("sets and gets request headers", func() {
|
||||
handlerCalled := make(chan struct{})
|
||||
mux.HandleFunc("/headers/request", func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue