use a chan instead of a context in Connection.HandshakeComplete (#3709)

This commit is contained in:
Marten Seemann 2023-03-27 00:51:56 +11:00 committed by GitHub
parent 41ddaa0262
commit a4b0265628
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 66 additions and 65 deletions

View file

@ -26,10 +26,10 @@ import (
var _ = Describe("Client", func() {
var (
cl *client
req *http.Request
origDialAddr = dialAddr
handshakeCtx context.Context // an already canceled context
cl *client
req *http.Request
origDialAddr = dialAddr
handshakeChan <-chan struct{} // a closed chan
)
BeforeEach(func() {
@ -43,9 +43,9 @@ var _ = Describe("Client", func() {
req, err = http.NewRequest("GET", "https://localhost:1337", nil)
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithCancel(context.Background())
cancel()
handshakeCtx = ctx
ch := make(chan struct{})
close(ch)
handshakeChan = ch
})
AfterEach(func() {
@ -202,7 +202,7 @@ var _ = Describe("Client", func() {
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes()
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
@ -328,7 +328,7 @@ var _ = Describe("Client", func() {
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
@ -434,7 +434,7 @@ var _ = Describe("Client", func() {
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
@ -696,7 +696,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("stream open error")
conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr)
conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError(testErr))
})
@ -721,7 +721,7 @@ var _ = Describe("Client", func() {
It("returns a response", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().HandshakeComplete().Return(handshakeChan),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
@ -739,7 +739,7 @@ var _ = Describe("Client", func() {
It("doesn't close the request stream, with DontCloseRequestStream set", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().HandshakeComplete().Return(handshakeChan),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
@ -758,7 +758,7 @@ var _ = Describe("Client", func() {
BeforeEach(func() {
strBuf = &bytes.Buffer{}
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().HandshakeComplete().Return(handshakeChan),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
)
body := &mockBody{}
@ -864,7 +864,7 @@ var _ = Describe("Client", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
errChan := make(chan error)
go func() {
@ -879,7 +879,7 @@ var _ = Describe("Client", func() {
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
@ -910,7 +910,7 @@ var _ = Describe("Client", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
@ -930,7 +930,7 @@ var _ = Describe("Client", func() {
Context("gzip compression", func() {
BeforeEach(func() {
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().HandshakeComplete().Return(handshakeChan)
})
It("adds the gzip header to requests", func() {