http3: correctly handle closed clients (#3684)

* http3: use a mock roundTripCloser in tests

* http3: correctly handle failed clients

Specifically,
* immediately remove a client when a request errored
* if that error was an idle error, and the client was a reused client
(from an earlier request that already completed the handshake),
re-dial the connection
This commit is contained in:
Marten Seemann 2023-01-28 00:49:52 -08:00 committed by GitHub
parent 7b2c69451e
commit 89769f409f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 336 additions and 147 deletions

View file

@ -26,7 +26,7 @@ import (
var _ = Describe("Client", func() {
var (
client *client
cl *client
req *http.Request
origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context
@ -35,10 +35,10 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
origDialAddr = dialAddr
hostname := "quic.clemente.io:1337"
var err error
client, err = newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil)
Expect(err).ToNot(HaveOccurred())
Expect(client.hostname).To(Equal(hostname))
cl = c.(*client)
Expect(cl.hostname).To(Equal(hostname))
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
@ -168,7 +168,7 @@ var _ = Describe("Client", func() {
It("refuses to do requests for the wrong host", func() {
req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTripOpt(req, RoundTripOpt{})
_, err = cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)"))
})
@ -179,7 +179,7 @@ var _ = Describe("Client", func() {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return nil, testErr
}
_, err = client.RoundTripOpt(req, RoundTripOpt{})
_, err = cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
})
@ -220,7 +220,7 @@ var _ = Describe("Client", func() {
It("hijacks a bidirectional stream of unknown frame type", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return true, nil
@ -235,7 +235,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -243,7 +243,7 @@ var _ = Describe("Client", func() {
It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, nil
@ -259,14 +259,14 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
It("closes the connection when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, errors.New("error in hijacker")
@ -282,7 +282,7 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})
@ -291,7 +291,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
unknownStr := mockquic.NewMockStream(mockCtrl)
done := make(chan struct{})
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
defer close(done)
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
@ -306,7 +306,7 @@ var _ = Describe("Client", func() {
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -348,7 +348,7 @@ var _ = Describe("Client", func() {
It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
@ -365,7 +365,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -375,7 +375,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
@ -389,7 +389,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -397,7 +397,7 @@ var _ = Describe("Client", func() {
It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
@ -415,7 +415,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
@ -467,7 +467,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
@ -492,7 +492,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead
})
@ -515,7 +515,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -539,7 +539,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorMissingSettings))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -563,7 +563,7 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorFrameError))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -586,13 +586,13 @@ var _ = Describe("Client", func() {
Expect(code).To(BeEquivalentTo(errorIDError))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
It("errors when the server advertises datagram support (and we enabled support for it)", func() {
client.opts.EnableDatagram = true
cl.opts.EnableDatagram = true
b := quicvarint.Append(nil, streamTypeControlStream)
b = (&settingsFrame{Datagram: true}).Append(b)
r := bytes.NewReader(b)
@ -613,7 +613,7 @@ var _ = Describe("Client", func() {
Expect(reason).To(Equal("missing QUIC Datagram support"))
close(done)
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
})
@ -705,7 +705,7 @@ var _ = Describe("Client", func() {
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
@ -721,7 +721,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
return 0, testErr
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET"))
})
@ -736,7 +736,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
@ -753,7 +753,7 @@ var _ = Describe("Client", func() {
)
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
@ -788,7 +788,7 @@ var _ = Describe("Client", func() {
<-done
return 0, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(strBuf)
Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
@ -812,7 +812,7 @@ var _ = Describe("Client", func() {
})
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Eventually(closed).Should(BeClosed())
})
@ -831,7 +831,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors
// the response body is sent asynchronously, while already reading the response
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
req, err := client.RoundTripOpt(req, RoundTripOpt{})
req, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
Expect(req.ContentLength).To(BeEquivalentTo(1337))
Eventually(done).Should(BeClosed())
@ -844,7 +844,7 @@ var _ = Describe("Client", func() {
r := bytes.NewReader(b)
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
Eventually(closed).Should(BeClosed())
})
@ -856,7 +856,7 @@ var _ = Describe("Client", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
Eventually(closed).Should(BeClosed())
})
@ -876,7 +876,7 @@ var _ = Describe("Client", func() {
errChan := make(chan error)
go func() {
_, err := client.RoundTripOpt(req, roundTripOpt)
_, err := cl.RoundTripOpt(req, roundTripOpt)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
@ -906,7 +906,7 @@ var _ = Describe("Client", func() {
<-canceled
return 0, errors.New("test done")
})
_, err := client.RoundTripOpt(req, roundTripOpt)
_, err := cl.RoundTripOpt(req, roundTripOpt)
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
@ -929,7 +929,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
cancel()
Eventually(done).Should(BeClosed())
@ -950,7 +950,7 @@ var _ = Describe("Client", func() {
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
)
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTripOpt(req, RoundTripOpt{})
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
@ -989,7 +989,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
@ -1012,7 +1012,7 @@ var _ = Describe("Client", func() {
str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{})
rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).ToNot(HaveOccurred())
data, err := io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())