http3: correctly handle closed clients (#3684)

* http3: use a mock roundTripCloser in tests

* http3: correctly handle failed clients

Specifically,
* immediately remove a client when a request errored
* if that error was an idle error, and the client was a reused client
(from an earlier request that already completed the handshake),
re-dial the connection
This commit is contained in:
Marten Seemann 2023-01-28 00:49:52 -08:00 committed by GitHub
parent 7b2c69451e
commit 89769f409f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 336 additions and 147 deletions

View file

@ -68,7 +68,9 @@ type client struct {
logger utils.Logger
}
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) {
var _ roundTripCloser = &client{}
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
if conf == nil {
conf = defaultQuicConfig.Clone()
} else if len(conf.Versions) == 0 {
@ -434,3 +436,15 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
return res, requestError{}
}
func (c *client) HandshakeComplete() bool {
if c.conn == nil {
return false
}
select {
case <-c.conn.HandshakeComplete().Done():
return true
default:
return false
}
}

View file

@ -26,7 +26,7 @@ import (
var _ = Describe("Client", func() {
var (
client *client
cl *client
req *http.Request
origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context
@ -35,10 +35,10 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
var err error
client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(client.hostname).To(Equal(hostname))
cl = c.(*client)
Expect(cl.hostname).To(Equal(hostname))
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
@ -168,7 +168,7 @@ var _ = Describe("Client", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTripOpt(req, RoundTripOpt{})
_, err = cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
@ -179,7 +179,7 @@ var _ = Describe("Client", func() {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return nil, testErr
}
_, err = client.RoundTripOpt(req, RoundTripOpt{})
_, err = cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
})
@ -220,7 +220,7 @@ var _ = Describe("Client", func() {
It("hijacks a bidirectional stream of unknown frame type", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return true, nil
@ -235,7 +235,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -243,7 +243,7 @@ var _ = Describe("Client", func() {
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, nil
@ -259,14 +259,14 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
It("closes the connection when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, errors.New("error in hijacker")
@ -282,7 +282,7 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
@ -291,7 +291,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
unknownStr := mockquic.NewMockStream(mockCtrl)
done := make(chan struct{})
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
defer close(done)
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
@ -306,7 +306,7 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -348,7 +348,7 @@ var _ = Describe("Client", func() {
It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
@ -365,7 +365,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -375,7 +375,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
@ -389,7 +389,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -397,7 +397,7 @@ var _ = Describe("Client", func() {
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
@ -415,7 +415,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -467,7 +467,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@ -492,7 +492,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
})
@ -515,7 +515,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -539,7 +539,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorMissingSettings))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -563,7 +563,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorFrameError))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -586,13 +586,13 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorIDError))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
It("errors when the server advertises datagram support (and we enabled support for it)", func() {
client.opts.EnableDatagram = true
cl.opts.EnableDatagram = true
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{Datagram: true}).Append(b)
r := bytes.NewReader(b)
@ -613,7 +613,7 @@ var _ = Describe("Client", func() {
Expect(reason).To(Equal("missing QUIC Datagram support"))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -705,7 +705,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
@ -721,7 +721,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
return 0, testErr
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
})
@ -736,7 +736,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
@ -753,7 +753,7 @@ var _ = Describe("Client", func() {
)
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
@ -788,7 +788,7 @@ var _ = Describe("Client", func() {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(strBuf)
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
@ -812,7 +812,7 @@ var _ = Describe("Client", func() {
})
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Eventually(closed).Should(BeClosed())
})
@ -831,7 +831,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
req, err := client.RoundTripOpt(req, RoundTripOpt{})
req, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
Expect(req.ContentLength).To(BeEquivalentTo(1337))
Eventually(done).Should(BeClosed())
@ -844,7 +844,7 @@ var _ = Describe("Client", func() {
r := bytes.NewReader(b)
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
Eventually(closed).Should(BeClosed())
})
@ -856,7 +856,7 @@ var _ = Describe("Client", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
Eventually(closed).Should(BeClosed())
})
@ -876,7 +876,7 @@ var _ = Describe("Client", func() {
errChan := make(chan error)
go func() {
_, err := client.RoundTripOpt(req, roundTripOpt)
_, err := cl.RoundTripOpt(req, roundTripOpt)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
@ -906,7 +906,7 @@ var _ = Describe("Client", func() {
<-canceled
return 0, errors.New("test done")
})
_, err := client.RoundTripOpt(req, roundTripOpt)
_, err := cl.RoundTripOpt(req, roundTripOpt)
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
@ -929,7 +929,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(done).Should(BeClosed())
@ -950,7 +950,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
@ -989,7 +989,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
@ -1012,7 +1012,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())

View file

@ -0,0 +1,78 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: roundtrip.go
// Package http3 is a generated GoMock package.
package http3
import (
http "net/http"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockRoundTripCloser is a mock of RoundTripCloser interface.
type MockRoundTripCloser struct {
ctrl *gomock.Controller
recorder *MockRoundTripCloserMockRecorder
}
// MockRoundTripCloserMockRecorder is the mock recorder for MockRoundTripCloser.
type MockRoundTripCloserMockRecorder struct {
mock *MockRoundTripCloser
}
// NewMockRoundTripCloser creates a new mock instance.
func NewMockRoundTripCloser(ctrl *gomock.Controller) *MockRoundTripCloser {
mock := &MockRoundTripCloser{ctrl: ctrl}
mock.recorder = &MockRoundTripCloserMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRoundTripCloser) EXPECT() *MockRoundTripCloserMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockRoundTripCloser) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockRoundTripCloserMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoundTripCloser)(nil).Close))
}
// HandshakeComplete mocks base method.
func (m *MockRoundTripCloser) HandshakeComplete() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandshakeComplete")
ret0, _ := ret[0].(bool)
return ret0
}
// HandshakeComplete indicates an expected call of HandshakeComplete.
func (mr *MockRoundTripCloserMockRecorder) HandshakeComplete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockRoundTripCloser)(nil).HandshakeComplete))
}
// RoundTripOpt mocks base method.
func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt) (*http.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RoundTripOpt", arg0, arg1)
ret0, _ := ret[0].(*http.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RoundTripOpt indicates an expected call of RoundTripOpt.
func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1)
}

View file

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
@ -17,6 +18,7 @@ import (
type roundTripCloser interface {
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
HandshakeComplete() bool
io.Closer
}
@ -75,7 +77,8 @@ type RoundTripper struct {
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
clients map[string]roundTripCloser
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]roundTripCloser
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
@ -131,11 +134,20 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
}
hostname := authorityAddr("https", hostnameFromRequest(req))
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
return cl.RoundTripOpt(req, opt)
rsp, err := cl.RoundTripOpt(req, opt)
if err != nil {
r.removeClient(hostname)
if isReused {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return r.RoundTripOpt(req, opt)
}
}
}
return rsp, err
}
// RoundTrip does a round trip.
@ -143,7 +155,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) {
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) {
r.mutex.Lock()
defer r.mutex.Unlock()
@ -154,10 +166,14 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo
client, ok := r.clients[hostname]
if !ok {
if onlyCached {
return nil, ErrNoCachedConn
return nil, false, ErrNoCachedConn
}
var err error
client, err = newClient(
newCl := newClient
if r.newClient != nil {
newCl = r.newClient
}
client, err = newCl(
hostname,
r.TLSClientConfig,
&roundTripperOpts{
@ -171,11 +187,22 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo
r.Dial,
)
if err != nil {
return nil, err
return nil, false, err
}
r.clients[hostname] = client
} else if client.HandshakeComplete() {
isReused = true
}
return client, nil
return client, isReused, nil
}
func (r *RoundTripper) removeClient(hostname string) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
return
}
delete(r.clients, hostname)
}
// Close closes the QUIC connections that this RoundTripper has used

View file

@ -10,27 +10,14 @@ import (
"time"
"github.com/quic-go/quic-go"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
type mockClient struct {
closed bool
}
func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
return &http.Response{Request: req}, nil
}
func (m *mockClient) Close() error {
m.closed = true
return nil
}
var _ roundTripCloser = &mockClient{}
//go:generate sh -c "./../mockgen_private.sh http3 mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 roundTripCloser"
type mockBody struct {
reader bytes.Reader
@ -60,57 +47,29 @@ func (m *mockBody) Close() error {
var _ = Describe("RoundTripper", func() {
var (
rt *RoundTripper
req1 *http.Request
conn *mockquic.MockEarlyConnection
handshakeCtx context.Context // an already canceled context
rt *RoundTripper
req *http.Request
)
BeforeEach(func() {
rt = &RoundTripper{}
var err error
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
})
Context("dialing hosts", func() {
origDialAddr := dialAddr
BeforeEach(func() {
conn = mockquic.NewMockEarlyConnection(mockCtrl)
origDialAddr = dialAddr
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return conn, nil
}
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("creates new clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-closed
return nil, errors.New("test done")
}).MaxTimes(1)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
return cl, nil
}
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1))
Eventually(closed).Should(BeClosed())
})
It("uses the quic.Config, if provided", func() {
@ -121,7 +80,7 @@ var _ = Describe("RoundTripper", func() {
return nil, errors.New("handshake error")
}
rt.QuicConfig = config
_, err := rt.RoundTrip(req1)
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("handshake error"))
Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout))
})
@ -133,33 +92,144 @@ var _ = Describe("RoundTripper", func() {
return nil, errors.New("handshake error")
}
rt.Dial = dialer
_, err := rt.RoundTrip(req1)
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("handshake error"))
Expect(dialed).To(BeTrue())
})
})
Context("reusing clients", func() {
var req1, req2 *http.Request
BeforeEach(func() {
var err error
req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred())
Expect(req1.URL).ToNot(Equal(req2.URL))
})
It("reuses existing clients", func() {
closed := make(chan struct{})
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
return &http.Response{Request: req}, nil
}).Times(2)
cl.EXPECT().HandshakeComplete().Return(true)
return cl, nil
}
rsp1, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp1.Request.URL).To(Equal(req1.URL))
rsp2, err := rt.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp2.Request.URL).To(Equal(req2.URL))
Expect(count).To(Equal(1))
})
It("immediately removes a clients when a request errored", func() {
testErr := errors.New("test err")
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2)
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-closed
return nil, errors.New("test done")
}).MaxTimes(1)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
return cl, nil
}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1))
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req2)
Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1))
Eventually(closed).Should(BeClosed())
Expect(count).To(Equal(2))
})
It("recreates a client when a request times out", func() {
var reqCount int
cl1 := NewMockRoundTripCloser(mockCtrl)
cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
reqCount++
if reqCount == 1 { // the first request is successful...
Expect(req.URL).To(Equal(req1.URL))
return &http.Response{Request: req}, nil
}
// ... after that, the connection timed out in the background
Expect(req.URL).To(Equal(req2.URL))
return nil, &qerr.IdleTimeoutError{}
}).Times(2)
cl1.EXPECT().HandshakeComplete().Return(true)
cl2 := NewMockRoundTripCloser(mockCtrl)
cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
return &http.Response{Request: req}, nil
})
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
if count == 1 {
return cl1, nil
}
return cl2, nil
}
rsp1, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
rsp2, err := rt.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
})
It("only issues a request once, even if a timeout error occurs", func() {
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
return cl, nil
}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
Expect(count).To(Equal(1))
})
It("handles a burst of requests", func() {
wait := make(chan struct{})
reqs := make(chan struct{}, 2)
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
reqs <- struct{}{}
<-wait
return nil, &qerr.IdleTimeoutError{}
}).Times(2)
cl.EXPECT().HandshakeComplete()
return cl, nil
}
done := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
defer func() { done <- struct{}{} }()
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
}()
go func() {
defer GinkgoRecover()
defer func() { done <- struct{}{} }()
_, err := rt.RoundTrip(req2)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
}()
// wait for both requests to be issued
Eventually(reqs).Should(Receive())
Eventually(reqs).Should(Receive())
close(wait) // now return the requests
Eventually(done).Should(Receive())
Eventually(done).Should(Receive())
Expect(count).To(Equal(1))
})
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
@ -181,66 +251,66 @@ var _ = Describe("RoundTripper", func() {
})
It("rejects requests without a URL", func() {
req1.URL = nil
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
req.URL = nil
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects request without a URL Host", func() {
req1.URL.Host = ""
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
req.URL.Host = ""
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: no Host in request URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
It("doesn't try to close the body if the request doesn't have one", func() {
req1.URL = nil
Expect(req1.Body).To(BeNil())
_, err := rt.RoundTrip(req1)
req.URL = nil
Expect(req.Body).To(BeNil())
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.URL"))
})
It("rejects requests without a header", func() {
req1.Header = nil
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
req.Header = nil
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.Header"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
It("rejects requests with invalid header name fields", func() {
req1.Header.Add("foobär", "value")
_, err := rt.RoundTrip(req1)
req.Header.Add("foobär", "value")
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
})
It("rejects requests with invalid header name values", func() {
req1.Header.Add("foo", string([]byte{0x7}))
_, err := rt.RoundTrip(req1)
req.Header.Add("foo", string([]byte{0x7}))
_, err := rt.RoundTrip(req)
Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
})
It("rejects requests with an invalid request method", func() {
req1.Method = "foobär"
req1.Body = &mockBody{}
_, err := rt.RoundTrip(req1)
req.Method = "foobär"
req.Body = &mockBody{}
_, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: invalid method \"foobär\""))
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
Expect(req.Body.(*mockBody).closed).To(BeTrue())
})
})
Context("closing", func() {
It("closes", func() {
rt.clients = make(map[string]roundTripCloser)
cl := &mockClient{}
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().Close()
rt.clients["foo.bar"] = cl
err := rt.Close()
Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero())
Expect(cl.closed).To(BeTrue())
})
It("closes a RoundTripper that has never been used", func() {