mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
http3: correctly handle closed clients (#3684)
* http3: use a mock roundTripCloser in tests * http3: correctly handle failed clients Specifically, * immediately remove a client when a request errored * if that error was an idle error, and the client was a reused client (from an earlier request that already completed the handshake), re-dial the connection
This commit is contained in:
parent
7b2c69451e
commit
89769f409f
5 changed files with 336 additions and 147 deletions
|
@ -68,7 +68,9 @@ type client struct {
|
|||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) {
|
||||
var _ roundTripCloser = &client{}
|
||||
|
||||
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
||||
if conf == nil {
|
||||
conf = defaultQuicConfig.Clone()
|
||||
} else if len(conf.Versions) == 0 {
|
||||
|
@ -434,3 +436,15 @@ func (c *client) doRequest(req *http.Request, str quic.Stream, opt RoundTripOpt,
|
|||
|
||||
return res, requestError{}
|
||||
}
|
||||
|
||||
func (c *client) HandshakeComplete() bool {
|
||||
if c.conn == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-c.conn.HandshakeComplete().Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ import (
|
|||
|
||||
var _ = Describe("Client", func() {
|
||||
var (
|
||||
client *client
|
||||
cl *client
|
||||
req *http.Request
|
||||
origDialAddr = dialAddr
|
||||
handshakeCtx context.Context // an already canceled context
|
||||
|
@ -35,10 +35,10 @@ var _ = Describe("Client", func() {
|
|||
BeforeEach(func() {
|
||||
origDialAddr = dialAddr
|
||||
hostname := "quic.clemente.io:1337"
|
||||
var err error
|
||||
client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
|
||||
c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(client.hostname).To(Equal(hostname))
|
||||
cl = c.(*client)
|
||||
Expect(cl.hostname).To(Equal(hostname))
|
||||
|
||||
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -168,7 +168,7 @@ var _ = Describe("Client", func() {
|
|||
It("refuses to do requests for the wrong host", func() {
|
||||
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err = cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
|
||||
})
|
||||
|
||||
|
@ -179,7 +179,7 @@ var _ = Describe("Client", func() {
|
|||
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
return nil, testErr
|
||||
}
|
||||
_, err = client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err = cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
})
|
||||
|
@ -220,7 +220,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("hijacks a bidirectional stream of unknown frame type", func() {
|
||||
frameTypeChan := make(chan FrameType, 1)
|
||||
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||
Expect(e).ToNot(HaveOccurred())
|
||||
frameTypeChan <- ft
|
||||
return true, nil
|
||||
|
@ -235,7 +235,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(request, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
|
@ -243,7 +243,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
|
||||
frameTypeChan := make(chan FrameType, 1)
|
||||
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||
Expect(e).ToNot(HaveOccurred())
|
||||
frameTypeChan <- ft
|
||||
return false, nil
|
||||
|
@ -259,14 +259,14 @@ var _ = Describe("Client", func() {
|
|||
return nil, errors.New("test done")
|
||||
})
|
||||
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
|
||||
_, err := client.RoundTripOpt(request, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
})
|
||||
|
||||
It("closes the connection when hijacker returned error", func() {
|
||||
frameTypeChan := make(chan FrameType, 1)
|
||||
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
|
||||
Expect(e).ToNot(HaveOccurred())
|
||||
frameTypeChan <- ft
|
||||
return false, errors.New("error in hijacker")
|
||||
|
@ -282,7 +282,7 @@ var _ = Describe("Client", func() {
|
|||
return nil, errors.New("test done")
|
||||
})
|
||||
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
|
||||
_, err := client.RoundTripOpt(request, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
|
||||
})
|
||||
|
@ -291,7 +291,7 @@ var _ = Describe("Client", func() {
|
|||
testErr := errors.New("test error")
|
||||
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||
done := make(chan struct{})
|
||||
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
|
||||
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
|
||||
defer close(done)
|
||||
Expect(e).To(MatchError(testErr))
|
||||
Expect(ft).To(BeZero())
|
||||
|
@ -306,7 +306,7 @@ var _ = Describe("Client", func() {
|
|||
return nil, errors.New("test done")
|
||||
})
|
||||
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
|
||||
_, err := client.RoundTripOpt(request, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
|
@ -348,7 +348,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("hijacks an unidirectional stream of unknown stream type", func() {
|
||||
streamTypeChan := make(chan StreamType, 1)
|
||||
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
streamTypeChan <- st
|
||||
return true
|
||||
|
@ -365,7 +365,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
|
@ -375,7 +375,7 @@ var _ = Describe("Client", func() {
|
|||
testErr := errors.New("test error")
|
||||
done := make(chan struct{})
|
||||
unknownStr := mockquic.NewMockStream(mockCtrl)
|
||||
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
|
||||
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
|
||||
defer close(done)
|
||||
Expect(st).To(BeZero())
|
||||
Expect(str).To(Equal(unknownStr))
|
||||
|
@ -389,7 +389,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
|
@ -397,7 +397,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
|
||||
streamTypeChan := make(chan StreamType, 1)
|
||||
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
streamTypeChan <- st
|
||||
return false
|
||||
|
@ -415,7 +415,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
|
@ -467,7 +467,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
|
||||
})
|
||||
|
@ -492,7 +492,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
|
||||
})
|
||||
|
@ -515,7 +515,7 @@ var _ = Describe("Client", func() {
|
|||
<-testDone
|
||||
return nil, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -539,7 +539,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(code).To(BeEquivalentTo(errorMissingSettings))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -563,7 +563,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(code).To(BeEquivalentTo(errorFrameError))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -586,13 +586,13 @@ var _ = Describe("Client", func() {
|
|||
Expect(code).To(BeEquivalentTo(errorIDError))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors when the server advertises datagram support (and we enabled support for it)", func() {
|
||||
client.opts.EnableDatagram = true
|
||||
cl.opts.EnableDatagram = true
|
||||
b := quicvarint.Append(nil, streamTypeControlStream)
|
||||
b = (&settingsFrame{Datagram: true}).Append(b)
|
||||
r := bytes.NewReader(b)
|
||||
|
@ -613,7 +613,7 @@ var _ = Describe("Client", func() {
|
|||
Expect(reason).To(Equal("missing QUIC Datagram support"))
|
||||
close(done)
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -705,7 +705,7 @@ var _ = Describe("Client", func() {
|
|||
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
|
||||
|
@ -721,7 +721,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
|
||||
return 0, testErr
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
|
||||
})
|
||||
|
@ -736,7 +736,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Close()
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
|
||||
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||
|
@ -753,7 +753,7 @@ var _ = Describe("Client", func() {
|
|||
)
|
||||
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
|
||||
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
|
||||
Expect(rsp.ProtoMajor).To(Equal(3))
|
||||
|
@ -788,7 +788,7 @@ var _ = Describe("Client", func() {
|
|||
<-done
|
||||
return 0, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
hfs := decodeHeader(strBuf)
|
||||
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
|
||||
|
@ -812,7 +812,7 @@ var _ = Describe("Client", func() {
|
|||
})
|
||||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
@ -831,7 +831,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
|
||||
// the response body is sent asynchronously, while already reading the response
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
|
||||
req, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
req, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(req.ContentLength).To(BeEquivalentTo(1337))
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -844,7 +844,7 @@ var _ = Describe("Client", func() {
|
|||
r := bytes.NewReader(b)
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
@ -856,7 +856,7 @@ var _ = Describe("Client", func() {
|
|||
closed := make(chan struct{})
|
||||
str.EXPECT().Close().Do(func() { close(closed) })
|
||||
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
@ -876,7 +876,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
_, err := client.RoundTripOpt(req, roundTripOpt)
|
||||
_, err := cl.RoundTripOpt(req, roundTripOpt)
|
||||
errChan <- err
|
||||
}()
|
||||
Consistently(errChan).ShouldNot(Receive())
|
||||
|
@ -906,7 +906,7 @@ var _ = Describe("Client", func() {
|
|||
<-canceled
|
||||
return 0, errors.New("test done")
|
||||
})
|
||||
_, err := client.RoundTripOpt(req, roundTripOpt)
|
||||
_, err := cl.RoundTripOpt(req, roundTripOpt)
|
||||
Expect(err).To(MatchError("test done"))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
@ -929,7 +929,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
|
||||
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
|
||||
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cancel()
|
||||
Eventually(done).Should(BeClosed())
|
||||
|
@ -950,7 +950,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
|
||||
)
|
||||
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
|
||||
_, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).To(MatchError("test done"))
|
||||
hfs := decodeHeader(buf)
|
||||
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
|
||||
|
@ -989,7 +989,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
|
||||
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
@ -1012,7 +1012,7 @@ var _ = Describe("Client", func() {
|
|||
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
||||
str.EXPECT().Close()
|
||||
|
||||
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
|
||||
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
data, err := io.ReadAll(rsp.Body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
|
78
http3/mock_roundtripcloser_test.go
Normal file
78
http3/mock_roundtripcloser_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: roundtrip.go
|
||||
|
||||
// Package http3 is a generated GoMock package.
|
||||
package http3
|
||||
|
||||
import (
|
||||
http "net/http"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockRoundTripCloser is a mock of RoundTripCloser interface.
|
||||
type MockRoundTripCloser struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRoundTripCloserMockRecorder
|
||||
}
|
||||
|
||||
// MockRoundTripCloserMockRecorder is the mock recorder for MockRoundTripCloser.
|
||||
type MockRoundTripCloserMockRecorder struct {
|
||||
mock *MockRoundTripCloser
|
||||
}
|
||||
|
||||
// NewMockRoundTripCloser creates a new mock instance.
|
||||
func NewMockRoundTripCloser(ctrl *gomock.Controller) *MockRoundTripCloser {
|
||||
mock := &MockRoundTripCloser{ctrl: ctrl}
|
||||
mock.recorder = &MockRoundTripCloserMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRoundTripCloser) EXPECT() *MockRoundTripCloserMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockRoundTripCloser) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockRoundTripCloserMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoundTripCloser)(nil).Close))
|
||||
}
|
||||
|
||||
// HandshakeComplete mocks base method.
|
||||
func (m *MockRoundTripCloser) HandshakeComplete() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HandshakeComplete")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HandshakeComplete indicates an expected call of HandshakeComplete.
|
||||
func (mr *MockRoundTripCloserMockRecorder) HandshakeComplete() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockRoundTripCloser)(nil).HandshakeComplete))
|
||||
}
|
||||
|
||||
// RoundTripOpt mocks base method.
|
||||
func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt) (*http.Response, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RoundTripOpt", arg0, arg1)
|
||||
ret0, _ := ret[0].(*http.Response)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RoundTripOpt indicates an expected call of RoundTripOpt.
|
||||
func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1)
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -17,6 +18,7 @@ import (
|
|||
|
||||
type roundTripCloser interface {
|
||||
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
||||
HandshakeComplete() bool
|
||||
io.Closer
|
||||
}
|
||||
|
||||
|
@ -75,7 +77,8 @@ type RoundTripper struct {
|
|||
// Zero means to use a default limit.
|
||||
MaxResponseHeaderBytes int64
|
||||
|
||||
clients map[string]roundTripCloser
|
||||
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
|
||||
clients map[string]roundTripCloser
|
||||
}
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
|
@ -131,11 +134,20 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
|
|||
}
|
||||
|
||||
hostname := authorityAddr("https", hostnameFromRequest(req))
|
||||
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
|
||||
cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cl.RoundTripOpt(req, opt)
|
||||
rsp, err := cl.RoundTripOpt(req, opt)
|
||||
if err != nil {
|
||||
r.removeClient(hostname)
|
||||
if isReused {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
return r.RoundTripOpt(req, opt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rsp, err
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
|
@ -143,7 +155,7 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripCloser, error) {
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
|
@ -154,10 +166,14 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo
|
|||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
if onlyCached {
|
||||
return nil, ErrNoCachedConn
|
||||
return nil, false, ErrNoCachedConn
|
||||
}
|
||||
var err error
|
||||
client, err = newClient(
|
||||
newCl := newClient
|
||||
if r.newClient != nil {
|
||||
newCl = r.newClient
|
||||
}
|
||||
client, err = newCl(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{
|
||||
|
@ -171,11 +187,22 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (roundTripClo
|
|||
r.Dial,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
r.clients[hostname] = client
|
||||
} else if client.HandshakeComplete() {
|
||||
isReused = true
|
||||
}
|
||||
return client, nil
|
||||
return client, isReused, nil
|
||||
}
|
||||
|
||||
func (r *RoundTripper) removeClient(hostname string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
if r.clients == nil {
|
||||
return
|
||||
}
|
||||
delete(r.clients, hostname)
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this RoundTripper has used
|
||||
|
|
|
@ -10,27 +10,14 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockClient) RoundTripOpt(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
return &http.Response{Request: req}, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &mockClient{}
|
||||
//go:generate sh -c "./../mockgen_private.sh http3 mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 roundTripCloser"
|
||||
|
||||
type mockBody struct {
|
||||
reader bytes.Reader
|
||||
|
@ -60,57 +47,29 @@ func (m *mockBody) Close() error {
|
|||
|
||||
var _ = Describe("RoundTripper", func() {
|
||||
var (
|
||||
rt *RoundTripper
|
||||
req1 *http.Request
|
||||
conn *mockquic.MockEarlyConnection
|
||||
handshakeCtx context.Context // an already canceled context
|
||||
rt *RoundTripper
|
||||
req *http.Request
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
rt = &RoundTripper{}
|
||||
var err error
|
||||
req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
||||
req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
handshakeCtx = ctx
|
||||
})
|
||||
|
||||
Context("dialing hosts", func() {
|
||||
origDialAddr := dialAddr
|
||||
|
||||
BeforeEach(func() {
|
||||
conn = mockquic.NewMockEarlyConnection(mockCtrl)
|
||||
origDialAddr = dialAddr
|
||||
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
|
||||
// return an error when trying to open a stream
|
||||
// we don't want to test all the dial logic here, just that dialing happens at all
|
||||
return conn, nil
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
dialAddr = origDialAddr
|
||||
})
|
||||
|
||||
It("creates new clients", func() {
|
||||
closed := make(chan struct{})
|
||||
testErr := errors.New("test err")
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
|
||||
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
|
||||
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||
<-closed
|
||||
return nil, errors.New("test done")
|
||||
}).MaxTimes(1)
|
||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
|
||||
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
|
||||
return cl, nil
|
||||
}
|
||||
_, err = rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("uses the quic.Config, if provided", func() {
|
||||
|
@ -121,7 +80,7 @@ var _ = Describe("RoundTripper", func() {
|
|||
return nil, errors.New("handshake error")
|
||||
}
|
||||
rt.QuicConfig = config
|
||||
_, err := rt.RoundTrip(req1)
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("handshake error"))
|
||||
Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout))
|
||||
})
|
||||
|
@ -133,33 +92,144 @@ var _ = Describe("RoundTripper", func() {
|
|||
return nil, errors.New("handshake error")
|
||||
}
|
||||
rt.Dial = dialer
|
||||
_, err := rt.RoundTrip(req1)
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("handshake error"))
|
||||
Expect(dialed).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("reusing clients", func() {
|
||||
var req1, req2 *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(req1.URL).ToNot(Equal(req2.URL))
|
||||
})
|
||||
|
||||
It("reuses existing clients", func() {
|
||||
closed := make(chan struct{})
|
||||
var count int
|
||||
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
||||
count++
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
return &http.Response{Request: req}, nil
|
||||
}).Times(2)
|
||||
cl.EXPECT().HandshakeComplete().Return(true)
|
||||
return cl, nil
|
||||
}
|
||||
rsp1, err := rt.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp1.Request.URL).To(Equal(req1.URL))
|
||||
rsp2, err := rt.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp2.Request.URL).To(Equal(req2.URL))
|
||||
Expect(count).To(Equal(1))
|
||||
})
|
||||
|
||||
It("immediately removes a clients when a request errored", func() {
|
||||
testErr := errors.New("test err")
|
||||
conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
|
||||
conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2)
|
||||
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
|
||||
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
|
||||
<-closed
|
||||
return nil, errors.New("test done")
|
||||
}).MaxTimes(1)
|
||||
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) })
|
||||
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req)
|
||||
|
||||
var count int
|
||||
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
||||
count++
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr)
|
||||
return cl, nil
|
||||
}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = rt.RoundTrip(req2)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
Expect(rt.clients).To(HaveLen(1))
|
||||
Eventually(closed).Should(BeClosed())
|
||||
Expect(count).To(Equal(2))
|
||||
})
|
||||
|
||||
It("recreates a client when a request times out", func() {
|
||||
var reqCount int
|
||||
cl1 := NewMockRoundTripCloser(mockCtrl)
|
||||
cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
reqCount++
|
||||
if reqCount == 1 { // the first request is successful...
|
||||
Expect(req.URL).To(Equal(req1.URL))
|
||||
return &http.Response{Request: req}, nil
|
||||
}
|
||||
// ... after that, the connection timed out in the background
|
||||
Expect(req.URL).To(Equal(req2.URL))
|
||||
return nil, &qerr.IdleTimeoutError{}
|
||||
}).Times(2)
|
||||
cl1.EXPECT().HandshakeComplete().Return(true)
|
||||
cl2 := NewMockRoundTripCloser(mockCtrl)
|
||||
cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
return &http.Response{Request: req}, nil
|
||||
})
|
||||
|
||||
var count int
|
||||
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
||||
count++
|
||||
if count == 1 {
|
||||
return cl1, nil
|
||||
}
|
||||
return cl2, nil
|
||||
}
|
||||
rsp1, err := rt.RoundTrip(req1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr))
|
||||
rsp2, err := rt.RoundTrip(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr))
|
||||
})
|
||||
|
||||
It("only issues a request once, even if a timeout error occurs", func() {
|
||||
var count int
|
||||
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
||||
count++
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{})
|
||||
return cl, nil
|
||||
}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
||||
Expect(count).To(Equal(1))
|
||||
})
|
||||
|
||||
It("handles a burst of requests", func() {
|
||||
wait := make(chan struct{})
|
||||
reqs := make(chan struct{}, 2)
|
||||
var count int
|
||||
rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) {
|
||||
count++
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) {
|
||||
reqs <- struct{}{}
|
||||
<-wait
|
||||
return nil, &qerr.IdleTimeoutError{}
|
||||
}).Times(2)
|
||||
cl.EXPECT().HandshakeComplete()
|
||||
return cl, nil
|
||||
}
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer func() { done <- struct{}{} }()
|
||||
_, err := rt.RoundTrip(req1)
|
||||
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
||||
}()
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
defer func() { done <- struct{}{} }()
|
||||
_, err := rt.RoundTrip(req2)
|
||||
Expect(err).To(MatchError(&qerr.IdleTimeoutError{}))
|
||||
}()
|
||||
// wait for both requests to be issued
|
||||
Eventually(reqs).Should(Receive())
|
||||
Eventually(reqs).Should(Receive())
|
||||
close(wait) // now return the requests
|
||||
Eventually(done).Should(Receive())
|
||||
Eventually(done).Should(Receive())
|
||||
Expect(count).To(Equal(1))
|
||||
})
|
||||
|
||||
It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() {
|
||||
|
@ -181,66 +251,66 @@ var _ = Describe("RoundTripper", func() {
|
|||
})
|
||||
|
||||
It("rejects requests without a URL", func() {
|
||||
req1.URL = nil
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
req.URL = nil
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: nil Request.URL"))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects request without a URL Host", func() {
|
||||
req1.URL.Host = ""
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
req.URL.Host = ""
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: no Host in request URL"))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("doesn't try to close the body if the request doesn't have one", func() {
|
||||
req1.URL = nil
|
||||
Expect(req1.Body).To(BeNil())
|
||||
_, err := rt.RoundTrip(req1)
|
||||
req.URL = nil
|
||||
Expect(req.Body).To(BeNil())
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: nil Request.URL"))
|
||||
})
|
||||
|
||||
It("rejects requests without a header", func() {
|
||||
req1.Header = nil
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
req.Header = nil
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: nil Request.Header"))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("rejects requests with invalid header name fields", func() {
|
||||
req1.Header.Add("foobär", "value")
|
||||
_, err := rt.RoundTrip(req1)
|
||||
req.Header.Add("foobär", "value")
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: invalid http header field name \"foobär\""))
|
||||
})
|
||||
|
||||
It("rejects requests with invalid header name values", func() {
|
||||
req1.Header.Add("foo", string([]byte{0x7}))
|
||||
_, err := rt.RoundTrip(req1)
|
||||
req.Header.Add("foo", string([]byte{0x7}))
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value"))
|
||||
})
|
||||
|
||||
It("rejects requests with an invalid request method", func() {
|
||||
req1.Method = "foobär"
|
||||
req1.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req1)
|
||||
req.Method = "foobär"
|
||||
req.Body = &mockBody{}
|
||||
_, err := rt.RoundTrip(req)
|
||||
Expect(err).To(MatchError("http3: invalid method \"foobär\""))
|
||||
Expect(req1.Body.(*mockBody).closed).To(BeTrue())
|
||||
Expect(req.Body.(*mockBody).closed).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("closing", func() {
|
||||
It("closes", func() {
|
||||
rt.clients = make(map[string]roundTripCloser)
|
||||
cl := &mockClient{}
|
||||
cl := NewMockRoundTripCloser(mockCtrl)
|
||||
cl.EXPECT().Close()
|
||||
rt.clients["foo.bar"] = cl
|
||||
err := rt.Close()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(rt.clients)).To(BeZero())
|
||||
Expect(cl.closed).To(BeTrue())
|
||||
})
|
||||
|
||||
It("closes a RoundTripper that has never been used", func() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue