http3: correctly handle closed clients (#3684)

* http3: use a mock roundTripCloser in tests

* http3: correctly handle failed clients

Specifically,
* immediately remove a client when a request errored
* if that error was an idle error, and the client was a reused client
(from an earlier request that already completed the handshake),
re-dial the connection
This commit is contained in:
Marten Seemann 2023-01-28 00:49:52 -08:00 committed by GitHub
parent 7b2c69451e
commit 89769f409f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 336 additions and 147 deletions

View file

@ -68,7 +68,9 @@ type client struct {
logger utils.Logger logger utils.Logger
} }
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { var _ roundTripCloser = &client{}
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
if conf == nil { if conf == nil {
conf = defaultQuicConfig.Clone() conf = defaultQuicConfig.Clone()
} else if len(conf.Versions) == 0 { } else if len(conf.Versions) == 0 {
@ -434,3 +436,15 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
return res, requestError{} return res, requestError{}
} }
func (c *client) HandshakeComplete() bool {
if c.conn == nil {
return false
}
select {
case <-c.conn.HandshakeComplete().Done():
return true
default:
return false
}
}

View file

@ -26,7 +26,7 @@ import (
var _ = Describe("Client", func() { var _ = Describe("Client", func() {
var ( var (
client *client cl *client
req *http.Request req *http.Request
origDialAddr = dialAddr origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context handshakeCtx context.Context // an already canceled context
@ -35,10 +35,10 @@ var _ = Describe("Client", func() {
BeforeEach(func() { BeforeEach(func() {
origDialAddr = dialAddr origDialAddr = dialAddr
hostname := "quic.clemente.io:1337" hostname := "quic.clemente.io:1337"
var err error c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(client.hostname).To(Equal(hostname)) cl = c.(*client)
Expect(cl.hostname).To(Equal(hostname))
req, err = http.NewRequest("GET", "https://localhost:1337", nil) req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -168,7 +168,7 @@ var _ = Describe("Client", func() {
It("refuses to do requests for the wrong host", func() { It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTripOpt(req, RoundTripOpt{}) _, err = cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
}) })
@ -179,7 +179,7 @@ var _ = Describe("Client", func() {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return nil, testErr return nil, testErr
} }
_, err = client.RoundTripOpt(req, RoundTripOpt{}) _, err = cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
}) })
@ -220,7 +220,7 @@ var _ = Describe("Client", func() {
It("hijacks a bidirectional stream of unknown frame type", func() { It("hijacks a bidirectional stream of unknown frame type", func() {
frameTypeChan := make(chan FrameType, 1) frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred()) Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft frameTypeChan <- ft
return true, nil return true, nil
@ -235,7 +235,7 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := client.RoundTripOpt(request, RoundTripOpt{}) _, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -243,7 +243,7 @@ var _ = Describe("Client", func() {
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1) frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred()) Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft frameTypeChan <- ft
return false, nil return false, nil
@ -259,14 +259,14 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{}) _, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
}) })
It("closes the connection when hijacker returned error", func() { It("closes the connection when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1) frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred()) Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft frameTypeChan <- ft
return false, errors.New("error in hijacker") return false, errors.New("error in hijacker")
@ -282,7 +282,7 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{}) _, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
}) })
@ -291,7 +291,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr := mockquic.NewMockStream(mockCtrl)
done := make(chan struct{}) done := make(chan struct{})
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
defer close(done) defer close(done)
Expect(e).To(MatchError(testErr)) Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero()) Expect(ft).To(BeZero())
@ -306,7 +306,7 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{}) _, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -348,7 +348,7 @@ var _ = Describe("Client", func() {
It("hijacks an unidirectional stream of unknown stream type", func() { It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1) streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st streamTypeChan <- st
return true return true
@ -365,7 +365,7 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -375,7 +375,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error") testErr := errors.New("test error")
done := make(chan struct{}) done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl) unknownStr := mockquic.NewMockStream(mockCtrl)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done) defer close(done)
Expect(st).To(BeZero()) Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr)) Expect(str).To(Equal(unknownStr))
@ -389,7 +389,7 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -397,7 +397,7 @@ var _ = Describe("Client", func() {
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1) streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st streamTypeChan <- st
return false return false
@ -415,7 +415,7 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -467,7 +467,7 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
}) })
@ -492,7 +492,7 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
}) })
@ -515,7 +515,7 @@ var _ = Describe("Client", func() {
<-testDone <-testDone
return nil, errors.New("test done") return nil, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -539,7 +539,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorMissingSettings)) Expect(code).To(BeEquivalentTo(errorMissingSettings))
close(done) close(done)
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -563,7 +563,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorFrameError)) Expect(code).To(BeEquivalentTo(errorFrameError))
close(done) close(done)
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -586,13 +586,13 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorIDError)) Expect(code).To(BeEquivalentTo(errorIDError))
close(done) close(done)
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("errors when the server advertises datagram support (and we enabled support for it)", func() { It("errors when the server advertises datagram support (and we enabled support for it)", func() {
client.opts.EnableDatagram = true cl.opts.EnableDatagram = true
b := quicvarint.Append(nil, streamTypeControlStream) b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{Datagram: true}).Append(b) b = (&settingsFrame{Datagram: true}).Append(b)
r := bytes.NewReader(b) r := bytes.NewReader(b)
@ -613,7 +613,7 @@ var _ = Describe("Client", func() {
Expect(reason).To(Equal("missing QUIC Datagram support")) Expect(reason).To(Equal("missing QUIC Datagram support"))
close(done) close(done)
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done")) Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -705,7 +705,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx) conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
}) })
@ -721,7 +721,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
return 0, testErr return 0, testErr
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
}) })
@ -736,7 +736,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close() str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0")) Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3)) Expect(rsp.ProtoMajor).To(Equal(3))
@ -753,7 +753,7 @@ var _ = Describe("Client", func() {
) )
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0")) Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3)) Expect(rsp.ProtoMajor).To(Equal(3))
@ -788,7 +788,7 @@ var _ = Describe("Client", func() {
<-done <-done
return 0, errors.New("test done") return 0, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done")) Expect(err).To(MatchError("test done"))
hfs := decodeHeader(strBuf) hfs := decodeHeader(strBuf)
Expect(hfs).To(HaveKeyWithValue(":method", "POST")) Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
@ -812,7 +812,7 @@ var _ = Describe("Client", func() {
}) })
closed := make(chan struct{}) closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Close().Do(func() { close(closed) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done")) Expect(err).To(MatchError("test done"))
Eventually(closed).Should(BeClosed()) Eventually(closed).Should(BeClosed())
}) })
@ -831,7 +831,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
// the response body is sent asynchronously, while already reading the response // the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
req, err := client.RoundTripOpt(req, RoundTripOpt{}) req, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(req.ContentLength).To(BeEquivalentTo(1337)) Expect(req.ContentLength).To(BeEquivalentTo(1337))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -844,7 +844,7 @@ var _ = Describe("Client", func() {
r := bytes.NewReader(b) r := bytes.NewReader(b)
str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
Eventually(closed).Should(BeClosed()) Eventually(closed).Should(BeClosed())
}) })
@ -856,7 +856,7 @@ var _ = Describe("Client", func() {
closed := make(chan struct{}) closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
Eventually(closed).Should(BeClosed()) Eventually(closed).Should(BeClosed())
}) })
@ -876,7 +876,7 @@ var _ = Describe("Client", func() {
errChan := make(chan error) errChan := make(chan error)
go func() { go func() {
_, err := client.RoundTripOpt(req, roundTripOpt) _, err := cl.RoundTripOpt(req, roundTripOpt)
errChan <- err errChan <- err
}() }()
Consistently(errChan).ShouldNot(Receive()) Consistently(errChan).ShouldNot(Receive())
@ -906,7 +906,7 @@ var _ = Describe("Client", func() {
<-canceled <-canceled
return 0, errors.New("test done") return 0, errors.New("test done")
}) })
_, err := client.RoundTripOpt(req, roundTripOpt) _, err := cl.RoundTripOpt(req, roundTripOpt)
Expect(err).To(MatchError("test done")) Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
@ -929,7 +929,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
cancel() cancel()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
@ -950,7 +950,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
) )
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTripOpt(req, RoundTripOpt{}) _, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done")) Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf) hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
@ -989,7 +989,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close() str.EXPECT().Close()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body) data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -1012,7 +1012,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close() str.EXPECT().Close()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{}) rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body) data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -0,0 +1,78 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: roundtrip.go
// Package http3 is a generated GoMock package.
package http3
import (
http "net/http"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockRoundTripCloser is a mock of RoundTripCloser interface.
type MockRoundTripCloser struct {
ctrl *gomock.Controller
recorder *MockRoundTripCloserMockRecorder
}
// MockRoundTripCloserMockRecorder is the mock recorder for MockRoundTripCloser.
type MockRoundTripCloserMockRecorder struct {
mock *MockRoundTripCloser
}
// NewMockRoundTripCloser creates a new mock instance.
func NewMockRoundTripCloser(ctrl *gomock.Controller) *MockRoundTripCloser {
mock := &MockRoundTripCloser{ctrl: ctrl}
mock.recorder = &MockRoundTripCloserMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRoundTripCloser) EXPECT() *MockRoundTripCloserMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockRoundTripCloser) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockRoundTripCloserMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoundTripCloser)(nil).Close))
}
// HandshakeComplete mocks base method.
func (m *MockRoundTripCloser) HandshakeComplete() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandshakeComplete")
ret0, _ := ret[0].(bool)
return ret0
}
// HandshakeComplete indicates an expected call of HandshakeComplete.
func (mr *MockRoundTripCloserMockRecorder) HandshakeComplete() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockRoundTripCloser)(nil).HandshakeComplete))
}
// RoundTripOpt mocks base method.
func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt) (*http.Response, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RoundTripOpt", arg0, arg1)
ret0, _ := ret[0].(*http.Response)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RoundTripOpt indicates an expected call of RoundTripOpt.
func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1)
}

View file

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@ -17,6 +18,7 @@ import (
type roundTripCloser interface { type roundTripCloser interface {
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
HandshakeComplete() bool
io.Closer io.Closer
} }
@ -75,7 +77,8 @@ type RoundTripper struct {
// Zero means to use a default limit. // Zero means to use a default limit.
MaxResponseHeaderBytes int64 MaxResponseHeaderBytes int64
clients map[string]roundTripCloser newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]roundTripCloser
} }
// RoundTripOpt are options for the Transport.RoundTripOpt method. // RoundTripOpt are options for the Transport.RoundTripOpt method.
@ -131,11 +134,20 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
} }
hostname := authorityAddr("https", hostnameFromRequest(req)) hostname := authorityAddr("https", hostnameFromRequest(req))
cl, err := r.getClient(hostname, opt.OnlyCachedConn) cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cl.RoundTripOpt(req, opt) rsp, err := cl.RoundTripOpt(req, opt)
if err != nil {
r.removeClient(hostname)
if isReused {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return r.RoundTripOpt(req, opt)
}
}
}
return rsp, err
} }
// RoundTrip does a round trip. // RoundTrip does a round trip.
@ -143,7 +155,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{}) return r.RoundTripOpt(req, RoundTripOpt{})
} }
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) { func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@ -154,10 +166,14 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
if !ok { if !ok {
if onlyCached { if onlyCached {
return nil, ErrNoCachedConn return nil, false, ErrNoCachedConn
} }
var err error var err error
client, err = newClient( newCl := newClient
if r.newClient != nil {
newCl = r.newClient
}
client, err = newCl(
hostname, hostname,
r.TLSClientConfig, r.TLSClientConfig,
&roundTripperOpts{ &roundTripperOpts{
@ -171,11 +187,22 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo
r.Dial, r.Dial,
) )
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
r.clients[hostname] = client r.clients[hostname] = client
} else if client.HandshakeComplete() {
isReused = true
} }
return client, nil return client, isReused, nil
}
func (r *RoundTripper) removeClient(hostname string) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
return
}
delete(r.clients, hostname)
} }
// Close closes the QUIC connections that this RoundTripper has used // Close closes the QUIC connections that this RoundTripper has used

View file

@ -10,27 +10,14 @@ import (
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
mockquic "github.com/quic-go/quic-go/internal/mocks/quic" "github.com/quic-go/quic-go/internal/qerr"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type mockClient struct { //go:generate sh -c "./../mockgen_private.sh http3 mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 roundTripCloser"
closed bool
}
func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
return &http.Response{Request: req}, nil
}
func (m *mockClient) Close() error {
m.closed = true
return nil
}
var _ roundTripCloser = &mockClient{}
type mockBody struct { type mockBody struct {
reader bytes.Reader reader bytes.Reader
@ -60,57 +47,29 @@ func (m *mockBody) Close() error {
var _ = Describe("RoundTripper", func() { var _ = Describe("RoundTripper", func() {
var ( var (
rt *RoundTripper rt *RoundTripper
req1 *http.Request req *http.Request
conn *mockquic.MockEarlyConnection
handshakeCtx context.Context // an already canceled context
) )
BeforeEach(func() { BeforeEach(func() {
rt = &RoundTripper{} rt = &RoundTripper{}
var err error var err error
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
}) })
Context("dialing hosts", func() { Context("dialing hosts", func() {
origDialAddr := dialAddr
BeforeEach(func() {
conn = mockquic.NewMockEarlyConnection(mockCtrl)
origDialAddr = dialAddr
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return conn, nil
}
})
AfterEach(func() {
dialAddr = origDialAddr
})
It("creates new clients", func() { It("creates new clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err") testErr := errors.New("test err")
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
conn.EXPECT().HandshakeComplete().Return(handshakeCtx) cl := NewMockRoundTripCloser(mockCtrl)
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return cl, nil
<-closed }
return nil, errors.New("test done")
}).MaxTimes(1)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
_, err = rt.RoundTrip(req) _, err = rt.RoundTrip(req)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1))
Eventually(closed).Should(BeClosed())
}) })
It("uses the quic.Config, if provided", func() { It("uses the quic.Config, if provided", func() {
@ -121,7 +80,7 @@ var _ = Describe("RoundTripper", func() {
return nil, errors.New("handshake error") return nil, errors.New("handshake error")
} }
rt.QuicConfig = config rt.QuicConfig = config
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("handshake error")) Expect(err).To(MatchError("handshake error"))
Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout))
}) })
@ -133,33 +92,144 @@ var _ = Describe("RoundTripper", func() {
return nil, errors.New("handshake error") return nil, errors.New("handshake error")
} }
rt.Dial = dialer rt.Dial = dialer
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("handshake error")) Expect(err).To(MatchError("handshake error"))
Expect(dialed).To(BeTrue()) Expect(dialed).To(BeTrue())
}) })
})
Context("reusing clients", func() {
var req1, req2 *http.Request
BeforeEach(func() {
var err error
req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
Expect(err).ToNot(HaveOccurred())
req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred())
Expect(req1.URL).ToNot(Equal(req2.URL))
})
It("reuses existing clients", func() { It("reuses existing clients", func() {
closed := make(chan struct{}) var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
return &http.Response{Request: req}, nil
}).Times(2)
cl.EXPECT().HandshakeComplete().Return(true)
return cl, nil
}
rsp1, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp1.Request.URL).To(Equal(req1.URL))
rsp2, err := rt.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp2.Request.URL).To(Equal(req2.URL))
Expect(count).To(Equal(1))
})
It("immediately removes a clients when a request errored", func() {
testErr := errors.New("test err") testErr := errors.New("test err")
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2) var count int
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { count++
<-closed cl := NewMockRoundTripCloser(mockCtrl)
return nil, errors.New("test done") cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
}).MaxTimes(1) return cl, nil
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) }
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) _, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1))
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = rt.RoundTrip(req2) _, err = rt.RoundTrip(req2)
Expect(err).To(MatchError(testErr)) Expect(err).To(MatchError(testErr))
Expect(rt.clients).To(HaveLen(1)) Expect(count).To(Equal(2))
Eventually(closed).Should(BeClosed()) })
It("recreates a client when a request times out", func() {
var reqCount int
cl1 := NewMockRoundTripCloser(mockCtrl)
cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
reqCount++
if reqCount == 1 { // the first request is successful...
Expect(req.URL).To(Equal(req1.URL))
return &http.Response{Request: req}, nil
}
// ... after that, the connection timed out in the background
Expect(req.URL).To(Equal(req2.URL))
return nil, &qerr.IdleTimeoutError{}
}).Times(2)
cl1.EXPECT().HandshakeComplete().Return(true)
cl2 := NewMockRoundTripCloser(mockCtrl)
cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
return &http.Response{Request: req}, nil
})
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
if count == 1 {
return cl1, nil
}
return cl2, nil
}
rsp1, err := rt.RoundTrip(req1)
Expect(err).ToNot(HaveOccurred())
Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
rsp2, err := rt.RoundTrip(req2)
Expect(err).ToNot(HaveOccurred())
Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
})
It("only issues a request once, even if a timeout error occurs", func() {
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
return cl, nil
}
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
Expect(count).To(Equal(1))
})
It("handles a burst of requests", func() {
wait := make(chan struct{})
reqs := make(chan struct{}, 2)
var count int
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
count++
cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
reqs <- struct{}{}
<-wait
return nil, &qerr.IdleTimeoutError{}
}).Times(2)
cl.EXPECT().HandshakeComplete()
return cl, nil
}
done := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
defer func() { done <- struct{}{} }()
_, err := rt.RoundTrip(req1)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
}()
go func() {
defer GinkgoRecover()
defer func() { done <- struct{}{} }()
_, err := rt.RoundTrip(req2)
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
}()
// wait for both requests to be issued
Eventually(reqs).Should(Receive())
Eventually(reqs).Should(Receive())
close(wait) // now return the requests
Eventually(done).Should(Receive())
Eventually(done).Should(Receive())
Expect(count).To(Equal(1))
}) })
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
@ -181,66 +251,66 @@ var _ = Describe("RoundTripper", func() {
}) })
It("rejects requests without a URL", func() { It("rejects requests without a URL", func() {
req1.URL = nil req.URL = nil
req1.Body = &mockBody{} req.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.URL")) Expect(err).To(MatchError("http3: nil Request.URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req.Body.(*mockBody).closed).To(BeTrue())
}) })
It("rejects request without a URL Host", func() { It("rejects request without a URL Host", func() {
req1.URL.Host = "" req.URL.Host = ""
req1.Body = &mockBody{} req.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: no Host in request URL")) Expect(err).To(MatchError("http3: no Host in request URL"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req.Body.(*mockBody).closed).To(BeTrue())
}) })
It("doesn't try to close the body if the request doesn't have one", func() { It("doesn't try to close the body if the request doesn't have one", func() {
req1.URL = nil req.URL = nil
Expect(req1.Body).To(BeNil()) Expect(req.Body).To(BeNil())
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.URL")) Expect(err).To(MatchError("http3: nil Request.URL"))
}) })
It("rejects requests without a header", func() { It("rejects requests without a header", func() {
req1.Header = nil req.Header = nil
req1.Body = &mockBody{} req.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: nil Request.Header")) Expect(err).To(MatchError("http3: nil Request.Header"))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req.Body.(*mockBody).closed).To(BeTrue())
}) })
It("rejects requests with invalid header name fields", func() { It("rejects requests with invalid header name fields", func() {
req1.Header.Add("foobär", "value") req.Header.Add("foobär", "value")
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: invalid http header field name \"foobär\"")) Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
}) })
It("rejects requests with invalid header name values", func() { It("rejects requests with invalid header name values", func() {
req1.Header.Add("foo", string([]byte{0x7})) req.Header.Add("foo", string([]byte{0x7}))
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value")) Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
}) })
It("rejects requests with an invalid request method", func() { It("rejects requests with an invalid request method", func() {
req1.Method = "foobär" req.Method = "foobär"
req1.Body = &mockBody{} req.Body = &mockBody{}
_, err := rt.RoundTrip(req1) _, err := rt.RoundTrip(req)
Expect(err).To(MatchError("http3: invalid method \"foobär\"")) Expect(err).To(MatchError("http3: invalid method \"foobär\""))
Expect(req1.Body.(*mockBody).closed).To(BeTrue()) Expect(req.Body.(*mockBody).closed).To(BeTrue())
}) })
}) })
Context("closing", func() { Context("closing", func() {
It("closes", func() { It("closes", func() {
rt.clients = make(map[string]roundTripCloser) rt.clients = make(map[string]roundTripCloser)
cl := &mockClient{} cl := NewMockRoundTripCloser(mockCtrl)
cl.EXPECT().Close()
rt.clients["foo.bar"] = cl rt.clients["foo.bar"] = cl
err := rt.Close() err := rt.Close()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(rt.clients)).To(BeZero()) Expect(len(rt.clients)).To(BeZero())
Expect(cl.closed).To(BeTrue())
}) })
It("closes a RoundTripper that has never been used", func() { It("closes a RoundTripper that has never been used", func() {